From 07340cdacad901a55abd0811a1fca86061b752bd Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Fri, 28 Sep 2018 01:42:53 +0100 Subject: untested stab at autocreating autojoin rooms --- synapse/config/registration.py | 4 ++++ synapse/handlers/register.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 0fb964eb67..686c7fa9f7 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -44,6 +44,7 @@ class RegistrationConfig(Config): ) self.auto_join_rooms = config.get("auto_join_rooms", []) + self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", true) def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) @@ -98,6 +99,9 @@ class RegistrationConfig(Config): # to these rooms #auto_join_rooms: # - "#example:example.com" + + # Have first user on server autocreate autojoin rooms + autocreate_auto_join_rooms: true """ % locals() def add_arguments(self, parser): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index da914c46ff..0e5337d26c 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -50,6 +50,8 @@ class RegistrationHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() + self._room_creation_handler = hs.get_room_creation_handler() + self._directory_handler = hs.get_handlers().directory_handler self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None @@ -513,6 +515,22 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def _join_user_to_room(self, requester, room_identifier): + + # try to create the room if we're the first user on the server + if self.config.autocreate_auto_join_rooms: + count = yield self.store.count_all_users() + if count == 1 and RoomAlias.is_valid(room_identifier): + info = yield self._room_creation_handler.create_room( + requester, + config={ + "preset": "public_chat", + }, + ratelimit=False, + ) + room_id = info["room_id"] + + yield create_association(self, requester.user, room_identifier, room_id) + room_id = None room_member_handler = self.hs.get_room_member_handler() if RoomID.is_valid(room_identifier): -- cgit 1.4.1 From 8f646f2d04e7d21aa0570826bd78a0bbc3bb706b Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Fri, 28 Sep 2018 15:37:28 +0100 Subject: fix UTs --- synapse/config/registration.py | 2 +- synapse/handlers/register.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 686c7fa9f7..dcf2374ed2 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -44,7 +44,7 @@ class RegistrationConfig(Config): ) self.auto_join_rooms = config.get("auto_join_rooms", []) - self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", true) + self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0e5337d26c..a358bfc723 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -50,8 +50,6 @@ class RegistrationHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() - self._room_creation_handler = hs.get_room_creation_handler() - self._directory_handler = hs.get_handlers().directory_handler self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None @@ -520,7 +518,8 @@ class RegistrationHandler(BaseHandler): if self.config.autocreate_auto_join_rooms: count = yield self.store.count_all_users() if count == 1 and RoomAlias.is_valid(room_identifier): - info = yield self._room_creation_handler.create_room( + room_creation_handler = hs.get_room_creation_handler() + info = yield room_creation_handler.create_room( requester, config={ "preset": "public_chat", @@ -529,7 +528,13 @@ class RegistrationHandler(BaseHandler): ) room_id = info["room_id"] - yield create_association(self, requester.user, room_identifier, room_id) + directory_handler = hs.get_handlers().directory_handler + yield directory_handler.create_association( + self, + requester.user, + room_identifier, + room_id + ) room_id = None room_member_handler = self.hs.get_room_member_handler() -- cgit 1.4.1 From 5b68f29f48acfa45e671af1fb325d0c0d532f3d6 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 29 Sep 2018 02:14:40 +0100 Subject: fix thinkos --- synapse/handlers/register.py | 12 ++++++------ synapse/storage/directory.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a358bfc723..05e8f4ea73 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -515,10 +515,10 @@ class RegistrationHandler(BaseHandler): def _join_user_to_room(self, requester, room_identifier): # try to create the room if we're the first user on the server - if self.config.autocreate_auto_join_rooms: + if self.hs.config.autocreate_auto_join_rooms: count = yield self.store.count_all_users() if count == 1 and RoomAlias.is_valid(room_identifier): - room_creation_handler = hs.get_room_creation_handler() + room_creation_handler = self.hs.get_room_creation_handler() info = yield room_creation_handler.create_room( requester, config={ @@ -528,11 +528,11 @@ class RegistrationHandler(BaseHandler): ) room_id = info["room_id"] - directory_handler = hs.get_handlers().directory_handler + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_identifier) yield directory_handler.create_association( - self, - requester.user, - room_identifier, + requester.user.to_string(), + room_alias, room_id ) diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index cfb687cb53..61a029a53c 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -90,7 +90,7 @@ class DirectoryWorkerStore(SQLBaseStore): class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def create_room_alias_association(self, room_alias, room_id, servers, creator=None): - """ Creates an associatin between a room alias and room_id/servers + """ Creates an association between a room alias and room_id/servers Args: room_alias (RoomAlias) -- cgit 1.4.1 From 23b6a0537f9e2b46a5857a7c7bd0a03d04d4095d Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 29 Sep 2018 02:19:37 +0100 Subject: emit room aliases event --- synapse/handlers/register.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 05e8f4ea73..7584064d5d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -531,9 +531,14 @@ class RegistrationHandler(BaseHandler): directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_identifier) yield directory_handler.create_association( - requester.user.to_string(), - room_alias, - room_id + user_id=requester.user.to_string(), + room_alias=room_alias, + room_id=room_id, + servers=[self.hs.hostname], + ) + + yield directory_handler.send_room_alias_update_event( + requester, requester.user.to_string(), room_id ) room_id = None -- cgit 1.4.1 From faa462ef79529f6b1802764038fd7365dfc37385 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 29 Sep 2018 02:21:01 +0100 Subject: changelog --- changelog.d/3975.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/3975.feature diff --git a/changelog.d/3975.feature b/changelog.d/3975.feature new file mode 100644 index 0000000000..5cd8ad6cca --- /dev/null +++ b/changelog.d/3975.feature @@ -0,0 +1 @@ +First user should autocreate autojoin rooms -- cgit 1.4.1 From 2dadc092b8f0e4cc1ac037ea54324efd906d4caf Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 4 Oct 2018 17:00:27 +0100 Subject: move logic into register, fix room alias localpart bug, tests --- synapse/handlers/register.py | 45 ++++++++++++------------------ tests/handlers/test_register.py | 62 ++++++++++++++++++++++++++++------------- tests/utils.py | 1 + 3 files changed, 62 insertions(+), 46 deletions(-) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7584064d5d..01cf7ab58e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -220,8 +220,26 @@ class RegistrationHandler(BaseHandler): # auto-join the user to any rooms we're supposed to dump them into fake_requester = create_requester(user_id) + + # try to create the room if we're the first user on the server + if self.hs.config.autocreate_auto_join_rooms: + count = yield self.store.count_all_users() + auto_create_rooms = count == 1 + for r in self.hs.config.auto_join_rooms: try: + if auto_create_rooms and RoomAlias.is_valid(r): + room_creation_handler = self.hs.get_room_creation_handler() + # create room expects the localpart of the room alias + room_alias_localpart = RoomAlias.from_string(r).localpart + yield room_creation_handler.create_room( + fake_requester, + config={ + "preset": "public_chat", + "room_alias_name": room_alias_localpart + }, + ratelimit=False, + ) yield self._join_user_to_room(fake_requester, r) except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) @@ -514,33 +532,6 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def _join_user_to_room(self, requester, room_identifier): - # try to create the room if we're the first user on the server - if self.hs.config.autocreate_auto_join_rooms: - count = yield self.store.count_all_users() - if count == 1 and RoomAlias.is_valid(room_identifier): - room_creation_handler = self.hs.get_room_creation_handler() - info = yield room_creation_handler.create_room( - requester, - config={ - "preset": "public_chat", - }, - ratelimit=False, - ) - room_id = info["room_id"] - - directory_handler = self.hs.get_handlers().directory_handler - room_alias = RoomAlias.from_string(room_identifier) - yield directory_handler.create_association( - user_id=requester.user.to_string(), - room_alias=room_alias, - room_id=room_id, - servers=[self.hs.hostname], - ) - - yield directory_handler.send_room_alias_update_event( - requester, requester.user.to_string(), room_id - ) - room_id = None room_member_handler = self.hs.get_room_member_handler() if RoomID.is_valid(room_identifier): diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 7b4ade3dfb..a150a897ac 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.api.errors import ResourceLimitError from synapse.handlers.register import RegistrationHandler -from synapse.types import UserID, create_requester +from synapse.types import RoomAlias, UserID, create_requester from tests.utils import setup_test_homeserver @@ -41,30 +41,28 @@ class RegistrationTestCase(unittest.TestCase): self.mock_captcha_client = Mock() self.hs = yield setup_test_homeserver( self.addCleanup, - handlers=None, - http_client=None, expire_access_token=True, - profile_handler=Mock(), ) self.macaroon_generator = Mock( generate_access_token=Mock(return_value='secret') ) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) - self.hs.handlers = RegistrationHandlers(self.hs) + # self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.store = self.hs.get_datastore() self.hs.config.max_mau_value = 50 self.lots_of_users = 100 self.small_number_of_users = 1 + self.requester = create_requester("@requester:test") + @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): - local_part = "someone" - display_name = "someone" - user_id = "@someone:test" - requester = create_requester("@as:test") + frank = UserID.from_string("@frank:test") + user_id = frank.to_string() + requester = create_requester(user_id) result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name + requester, frank.localpart, "Frankie" ) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -78,12 +76,11 @@ class RegistrationTestCase(unittest.TestCase): token="jkv;g498752-43gj['eamb!-5", password_hash=None, ) - local_part = "frank" - display_name = "Frank" - user_id = "@frank:test" - requester = create_requester("@as:test") + local_part = frank.localpart + user_id = frank.to_string() + requester = create_requester(user_id) result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name + requester, local_part, None ) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -92,7 +89,7 @@ class RegistrationTestCase(unittest.TestCase): def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.handler.get_or_create_user("requester", 'a', "display_name") + yield self.handler.get_or_create_user(self.requester, 'a', "display_name") @defer.inlineCallbacks def test_get_or_create_user_mau_not_blocked(self): @@ -101,7 +98,7 @@ class RegistrationTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception - yield self.handler.get_or_create_user("@user:server", 'c', "User") + yield self.handler.get_or_create_user(self.requester, 'c', "User") @defer.inlineCallbacks def test_get_or_create_user_mau_blocked(self): @@ -110,13 +107,13 @@ class RegistrationTestCase(unittest.TestCase): return_value=defer.succeed(self.lots_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user("requester", 'b', "display_name") + yield self.handler.get_or_create_user(self.requester, 'b', "display_name") self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user("requester", 'b', "display_name") + yield self.handler.get_or_create_user(self.requester, 'b', "display_name") @defer.inlineCallbacks def test_register_mau_blocked(self): @@ -147,3 +144,30 @@ class RegistrationTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError): yield self.handler.register_saml2(localpart="local_part") + + @defer.inlineCallbacks + def test_auto_create_auto_join_rooms(self): + room_alias_str = "#room:test" + self.hs.config.autocreate_auto_join_rooms = True + self.hs.config.auto_join_rooms = [room_alias_str] + + res = yield self.handler.register(localpart='jeff') + rooms = yield self.store.get_rooms_for_user(res[0]) + + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = yield directory_handler.get_association(room_alias) + + self.assertTrue(room_id['room_id'] in rooms) + self.assertEqual(len(rooms), 1) + + @defer.inlineCallbacks + def test_auto_create_auto_join_rooms_with_no_rooms(self): + self.hs.config.autocreate_auto_join_rooms = True + self.hs.config.auto_join_rooms = [] + frank = UserID.from_string("@frank:test") + res = yield self.handler.register(frank.localpart) + self.assertEqual(res[0], frank.to_string()) + rooms = yield self.store.get_rooms_for_user(res[0]) + + self.assertEqual(len(rooms), 0) diff --git a/tests/utils.py b/tests/utils.py index aaed1149c3..d34c224fb1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -149,6 +149,7 @@ def setup_test_homeserver( config.block_events_without_consent_error = None config.media_storage_providers = [] config.auto_join_rooms = [] + config.autocreate_auto_join_rooms = True config.limit_usage_by_mau = False config.hs_disabled = False config.hs_disabled_message = "" -- cgit 1.4.1 From a2bfb778c8a2a91bdff1a5824571d91f4f4536d3 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Fri, 12 Oct 2018 18:17:36 +0100 Subject: improve auto room join logic, comments and tests --- changelog.d/3975.feature | 2 +- synapse/config/registration.py | 11 ++++++++++- synapse/handlers/register.py | 11 ++++++++--- tests/handlers/test_register.py | 21 +++++++++++++++++---- 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/changelog.d/3975.feature b/changelog.d/3975.feature index 5cd8ad6cca..496ba4f4a0 100644 --- a/changelog.d/3975.feature +++ b/changelog.d/3975.feature @@ -1 +1 @@ -First user should autocreate autojoin rooms +Servers with auto join rooms, should autocreate those rooms when first user registers diff --git a/synapse/config/registration.py b/synapse/config/registration.py index dcf2374ed2..43ff20a637 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -15,6 +15,8 @@ from distutils.util import strtobool +from synapse.config._base import ConfigError +from synapse.types import RoomAlias from synapse.util.stringutils import random_string_with_symbols from ._base import Config @@ -44,6 +46,9 @@ class RegistrationConfig(Config): ) self.auto_join_rooms = config.get("auto_join_rooms", []) + for room_alias in self.auto_join_rooms: + if not RoomAlias.is_valid(room_alias): + raise ConfigError('Invalid auto_join_rooms entry %s' % room_alias) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) def default_config(self, **kwargs): @@ -100,7 +105,11 @@ class RegistrationConfig(Config): #auto_join_rooms: # - "#example:example.com" - # Have first user on server autocreate autojoin rooms + # Where auto_join_rooms are specified, setting this flag ensures that the + # the rooms exists by creating them when the first user on the + # homeserver registers. + # Setting to false means that if the rooms are not manually created, + # users cannot be auto joined since they do not exist. autocreate_auto_join_rooms: true """ % locals() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 01cf7ab58e..2b269ab692 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -26,6 +26,7 @@ from synapse.api.errors import ( RegistrationError, SynapseError, ) +from synapse.config._base import ConfigError from synapse.http.client import CaptchaServerHttpClient from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.util.async_helpers import Linearizer @@ -222,14 +223,19 @@ class RegistrationHandler(BaseHandler): fake_requester = create_requester(user_id) # try to create the room if we're the first user on the server + should_auto_create_rooms = False if self.hs.config.autocreate_auto_join_rooms: count = yield self.store.count_all_users() - auto_create_rooms = count == 1 + should_auto_create_rooms = count == 1 for r in self.hs.config.auto_join_rooms: try: - if auto_create_rooms and RoomAlias.is_valid(r): + if should_auto_create_rooms: room_creation_handler = self.hs.get_room_creation_handler() + if self.hs.hostname != RoomAlias.from_string(r).domain: + raise ConfigError( + 'Cannot create room alias %s, it does not match server domain' + ) # create room expects the localpart of the room alias room_alias_localpart = RoomAlias.from_string(r).localpart yield room_creation_handler.create_room( @@ -531,7 +537,6 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def _join_user_to_room(self, requester, room_identifier): - room_id = None room_member_handler = self.hs.get_room_member_handler() if RoomID.is_valid(room_identifier): diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index a150a897ac..3e9a190727 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -47,7 +47,6 @@ class RegistrationTestCase(unittest.TestCase): generate_access_token=Mock(return_value='secret') ) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) - # self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.store = self.hs.get_datastore() self.hs.config.max_mau_value = 50 @@ -148,9 +147,7 @@ class RegistrationTestCase(unittest.TestCase): @defer.inlineCallbacks def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" - self.hs.config.autocreate_auto_join_rooms = True self.hs.config.auto_join_rooms = [room_alias_str] - res = yield self.handler.register(localpart='jeff') rooms = yield self.store.get_rooms_for_user(res[0]) @@ -163,11 +160,27 @@ class RegistrationTestCase(unittest.TestCase): @defer.inlineCallbacks def test_auto_create_auto_join_rooms_with_no_rooms(self): - self.hs.config.autocreate_auto_join_rooms = True self.hs.config.auto_join_rooms = [] frank = UserID.from_string("@frank:test") res = yield self.handler.register(frank.localpart) self.assertEqual(res[0], frank.to_string()) rooms = yield self.store.get_rooms_for_user(res[0]) + self.assertEqual(len(rooms), 0) + + @defer.inlineCallbacks + def test_auto_create_auto_join_where_room_is_another_domain(self): + self.hs.config.auto_join_rooms = ["#room:another"] + frank = UserID.from_string("@frank:test") + res = yield self.handler.register(frank.localpart) + self.assertEqual(res[0], frank.to_string()) + rooms = yield self.store.get_rooms_for_user(res[0]) + self.assertEqual(len(rooms), 0) + @defer.inlineCallbacks + def test_auto_create_auto_join_where_auto_create_is_false(self): + self.hs.config.autocreate_auto_join_rooms = False + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + res = yield self.handler.register(localpart='jeff') + rooms = yield self.store.get_rooms_for_user(res[0]) self.assertEqual(len(rooms), 0) -- cgit 1.4.1 From 1ccafb0c5e3e6c9c64ce231163e7c73f05b910fa Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Sat, 13 Oct 2018 21:14:21 +0100 Subject: no need to join room if creator --- synapse/handlers/register.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 2b269ab692..2f7bdb0a20 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -246,7 +246,8 @@ class RegistrationHandler(BaseHandler): }, ratelimit=False, ) - yield self._join_user_to_room(fake_requester, r) + else: + yield self._join_user_to_room(fake_requester, r) except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) -- cgit 1.4.1 From c6584f4b5f9cc495478e03e01f85fd2399cf6f8d Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 17 Oct 2018 11:36:41 +0100 Subject: clean up config error logic and imports --- changelog.d/3975.feature | 2 +- synapse/config/registration.py | 9 ++++----- synapse/handlers/register.py | 30 ++++++++++++++++-------------- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/changelog.d/3975.feature b/changelog.d/3975.feature index 496ba4f4a0..79c2711fb4 100644 --- a/changelog.d/3975.feature +++ b/changelog.d/3975.feature @@ -1 +1 @@ -Servers with auto join rooms, should autocreate those rooms when first user registers +Servers with auto-join rooms, should automatically create those rooms when first user registers diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 43ff20a637..4b9bf6f2d1 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -15,11 +15,10 @@ from distutils.util import strtobool -from synapse.config._base import ConfigError from synapse.types import RoomAlias from synapse.util.stringutils import random_string_with_symbols -from ._base import Config +from ._base import Config, ConfigError class RegistrationConfig(Config): @@ -48,7 +47,7 @@ class RegistrationConfig(Config): self.auto_join_rooms = config.get("auto_join_rooms", []) for room_alias in self.auto_join_rooms: if not RoomAlias.is_valid(room_alias): - raise ConfigError('Invalid auto_join_rooms entry %s' % room_alias) + raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) def default_config(self, **kwargs): @@ -106,10 +105,10 @@ class RegistrationConfig(Config): # - "#example:example.com" # Where auto_join_rooms are specified, setting this flag ensures that the - # the rooms exists by creating them when the first user on the + # the rooms exist by creating them when the first user on the # homeserver registers. # Setting to false means that if the rooms are not manually created, - # users cannot be auto joined since they do not exist. + # users cannot be auto-joined since they do not exist. autocreate_auto_join_rooms: true """ % locals() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 2f7bdb0a20..1b5873c8d7 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -26,7 +26,6 @@ from synapse.api.errors import ( RegistrationError, SynapseError, ) -from synapse.config._base import ConfigError from synapse.http.client import CaptchaServerHttpClient from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.util.async_helpers import Linearizer @@ -51,6 +50,7 @@ class RegistrationHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() + self.room_creation_handler = self.hs.get_room_creation_handler() self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None @@ -231,21 +231,23 @@ class RegistrationHandler(BaseHandler): for r in self.hs.config.auto_join_rooms: try: if should_auto_create_rooms: - room_creation_handler = self.hs.get_room_creation_handler() if self.hs.hostname != RoomAlias.from_string(r).domain: - raise ConfigError( - 'Cannot create room alias %s, it does not match server domain' + logger.warn( + 'Cannot create room alias %s, ' + 'it does not match server domain' % (r,) + ) + raise SynapseError() + else: + # create room expects the localpart of the room alias + room_alias_localpart = RoomAlias.from_string(r).localpart + yield self.room_creation_handler.create_room( + fake_requester, + config={ + "preset": "public_chat", + "room_alias_name": room_alias_localpart + }, + ratelimit=False, ) - # create room expects the localpart of the room alias - room_alias_localpart = RoomAlias.from_string(r).localpart - yield room_creation_handler.create_room( - fake_requester, - config={ - "preset": "public_chat", - "room_alias_name": room_alias_localpart - }, - ratelimit=False, - ) else: yield self._join_user_to_room(fake_requester, r) except Exception as e: -- cgit 1.4.1 From 084046456ec88588779a62f9378c1a8e911bfc7c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 17 Oct 2018 16:14:04 +0100 Subject: Add config option to control alias creation --- synapse/config/homeserver.py | 3 +- synapse/config/room_directory.py | 101 ++++++++++++++++++++++++++++++++ synapse/federation/federation_server.py | 16 +---- synapse/handlers/directory.py | 9 +++ synapse/util/__init__.py | 21 +++++++ 5 files changed, 135 insertions(+), 15 deletions(-) create mode 100644 synapse/config/room_directory.py diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index b8d5690f2b..10dd40159f 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -31,6 +31,7 @@ from .push import PushConfig from .ratelimiting import RatelimitConfig from .registration import RegistrationConfig from .repository import ContentRepositoryConfig +from .room_directory import RoomDirectoryConfig from .saml2 import SAML2Config from .server import ServerConfig from .server_notices_config import ServerNoticesConfig @@ -49,7 +50,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, WorkerConfig, PasswordAuthProviderConfig, PushConfig, SpamCheckerConfig, GroupsConfig, UserDirectoryConfig, ConsentConfig, - ServerNoticesConfig, + ServerNoticesConfig, RoomDirectoryConfig, ): pass diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py new file mode 100644 index 0000000000..41ef3217e8 --- /dev/null +++ b/synapse/config/room_directory.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util import glob_to_regex + +from ._base import Config, ConfigError + + +class RoomDirectoryConfig(Config): + def read_config(self, config): + alias_creation_rules = config["alias_creation_rules"] + + self._alias_creation_rules = [ + _AliasRule(rule) + for rule in alias_creation_rules + ] + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # The `alias_creation` option controls who's allowed to create aliases + # on this server. + # + # The format of this option is a list of rules that contain globs that + # match against user_id and the new alias (fully qualified with server + # name). The action in the first rule that matches is taken, which can + # currently either be "allowed" or "denied". + # + # If no rules match the request is denied. + alias_creation_rules: + - user_id: "*" + alias: "*" + action: allowed + """ + + def is_alias_creation_allowed(self, user_id, alias): + """Checks if the given user is allowed to create the given alias + + Args: + user_id (str) + alias (str) + + Returns: + boolean: True if user is allowed to crate the alias + """ + for rule in self._alias_creation_rules: + if rule.matches(user_id, alias): + return rule.action == "allowed" + + return False + + +class _AliasRule(object): + def __init__(self, rule): + action = rule["action"] + user_id = rule["user_id"] + alias = rule["alias"] + + if action in ("allowed", "denied"): + self.action = action + else: + raise ConfigError( + "alias_creation_rules rules can only have action of 'allowed'" + " or 'denied'" + ) + + try: + self._user_id_regex = glob_to_regex(user_id) + self._alias_regex = glob_to_regex(alias) + except Exception as e: + raise ConfigError("Failed to parse glob into regex: %s", e) + + def matches(self, user_id, alias): + """Tests if this rule matches the given user_id and alias. + + Args: + user_id (str) + alias (str) + + Returns: + boolean + """ + + if not self._user_id_regex.search(user_id): + return False + + if not self._alias_regex.search(alias): + return False + + return True diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 4efe95faa4..d041c26824 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import re import six from six import iteritems @@ -44,6 +43,7 @@ from synapse.replication.http.federation import ( ReplicationGetQueryRestServlet, ) from synapse.types import get_domain_from_id +from synapse.util import glob_to_regex from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache from synapse.util.logcontext import nested_logging_context @@ -729,22 +729,10 @@ def _acl_entry_matches(server_name, acl_entry): if not isinstance(acl_entry, six.string_types): logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)) return False - regex = _glob_to_regex(acl_entry) + regex = glob_to_regex(acl_entry) return regex.match(server_name) -def _glob_to_regex(glob): - res = '' - for c in glob: - if c == '*': - res = res + '.*' - elif c == '?': - res = res + '.' - else: - res = res + re.escape(c) - return re.compile(res + "\\Z", re.IGNORECASE) - - class FederationHandlerRegistry(object): """Allows classes to register themselves as handlers for a given EDU or query type for incoming federation traffic. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 02f12f6645..7d67bf803a 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -43,6 +43,7 @@ class DirectoryHandler(BaseHandler): self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() + self.config = hs.config self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -111,6 +112,14 @@ class DirectoryHandler(BaseHandler): 403, "This user is not permitted to create this alias", ) + if not self.config.is_alias_creation_allowed(user_id, room_alias.to_string()): + # Lets just return a generic message, as there may be all sorts of + # reasons why we said no. TODO: Allow configurable error messages + # per alias creation rule? + raise SynapseError( + 403, "Not allowed to create alias", + ) + can_create = yield self.can_modify_alias( room_alias, user_id=user_id diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 9a8fae0497..163e4b35ff 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import re from itertools import islice import attr @@ -138,3 +139,23 @@ def log_failure(failure, msg, consumeErrors=True): if not consumeErrors: return failure + + +def glob_to_regex(glob): + """Converts a glob to a compiled regex object + + Args: + glob (str) + + Returns: + re.RegexObject + """ + res = '' + for c in glob: + if c == '*': + res = res + '.*' + elif c == '?': + res = res + '.' + else: + res = res + re.escape(c) + return re.compile(res + "\\Z", re.IGNORECASE) -- cgit 1.4.1 From f9d6c677eac35c926339a904b1f0c8c9dbd9049a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 17 Oct 2018 16:36:39 +0100 Subject: Newsfile --- changelog.d/4051.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/4051.feature diff --git a/changelog.d/4051.feature b/changelog.d/4051.feature new file mode 100644 index 0000000000..9c1b3a72a0 --- /dev/null +++ b/changelog.d/4051.feature @@ -0,0 +1 @@ +Add config option to control alias creation -- cgit 1.4.1 From 9fafdfa97d87006177d13d4b80aeebfc4ded4bee Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Oct 2018 14:21:09 +0100 Subject: Anchor returned regex to start and end of string --- synapse/util/__init__.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 163e4b35ff..0ae7e2ef3b 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -142,7 +142,9 @@ def log_failure(failure, msg, consumeErrors=True): def glob_to_regex(glob): - """Converts a glob to a compiled regex object + """Converts a glob to a compiled regex object. + + The regex is anchored at the beginning and end of the string. Args: glob (str) @@ -158,4 +160,6 @@ def glob_to_regex(glob): res = res + '.' else: res = res + re.escape(c) - return re.compile(res + "\\Z", re.IGNORECASE) + + # \A anchors at start of string, \Z at end of string + return re.compile(r"\A" + res + r"\Z", re.IGNORECASE) -- cgit 1.4.1 From 1b4bf232b9fd348a94b8bc4e9c851ed5b6d8e801 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Oct 2018 16:04:50 +0100 Subject: Add tests for config generation --- tests/config/test_room_directory.py | 67 +++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/config/test_room_directory.py diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py new file mode 100644 index 0000000000..75021a5f04 --- /dev/null +++ b/tests/config/test_room_directory.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import yaml + +from synapse.config.room_directory import RoomDirectoryConfig + +from tests import unittest + + +class RoomDirectoryConfigTestCase(unittest.TestCase): + def test_alias_creation_acl(self): + config = yaml.load(""" + alias_creation_rules: + - user_id: "*bob*" + alias: "*" + action: "denied" + - user_id: "*" + alias: "#unofficial_*" + action: "allowed" + - user_id: "@foo*:example.com" + alias: "*" + action: "allowed" + - user_id: "@gah:example.com" + alias: "#goo:example.com" + action: "allowed" + """) + + rd_config = RoomDirectoryConfig() + rd_config.read_config(config) + + self.assertFalse(rd_config.is_alias_creation_allowed( + user_id="@bob:example.com", + alias="#test:example.com", + )) + + self.assertTrue(rd_config.is_alias_creation_allowed( + user_id="@test:example.com", + alias="#unofficial_st:example.com", + )) + + self.assertTrue(rd_config.is_alias_creation_allowed( + user_id="@foobar:example.com", + alias="#test:example.com", + )) + + self.assertTrue(rd_config.is_alias_creation_allowed( + user_id="@gah:example.com", + alias="#goo:example.com", + )) + + self.assertFalse(rd_config.is_alias_creation_allowed( + user_id="@test:example.com", + alias="#test:example.com", + )) -- cgit 1.4.1 From 3c580c2b47fab5f85c76a7061065e11b7d0eaeb8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Oct 2018 16:14:41 +0100 Subject: Add tests for alias creation rules --- tests/handlers/test_directory.py | 48 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index ec7355688b..4f299b74ba 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -18,7 +18,9 @@ from mock import Mock from twisted.internet import defer +from synapse.config.room_directory import RoomDirectoryConfig from synapse.handlers.directory import DirectoryHandler +from synapse.rest.client.v1 import directory, room from synapse.types import RoomAlias from tests import unittest @@ -102,3 +104,49 @@ class DirectoryTestCase(unittest.TestCase): ) self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + + +class TestCreateAliasACL(unittest.HomeserverTestCase): + user_id = "@test:test" + + servlets = [directory.register_servlets, room.register_servlets] + + def prepare(self, hs, reactor, clock): + # We cheekily override the config to add custom alias creation rules + config = {} + config["alias_creation_rules"] = [ + { + "user_id": "*", + "alias": "#unofficial_*", + "action": "allowed", + } + ] + + rd_config = RoomDirectoryConfig() + rd_config.read_config(config) + + self.hs.config.is_alias_creation_allowed = rd_config.is_alias_creation_allowed + + return hs + + def test_denied(self): + room_id = self.helper.create_room_as(self.user_id) + + request, channel = self.make_request( + "PUT", + b"directory/room/%23test%3Atest", + ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ) + self.render(request) + self.assertEquals(403, channel.code, channel.result) + + def test_allowed(self): + room_id = self.helper.create_room_as(self.user_id) + + request, channel = self.make_request( + "PUT", + b"directory/room/%23unofficial_test%3Atest", + ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) -- cgit 1.4.1 From eba48c0f16d08d5806da73d0b51d62de6aaee118 Mon Sep 17 00:00:00 2001 From: steamport Date: Fri, 19 Oct 2018 19:58:28 +0000 Subject: Add Caddy example to README --- README.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.rst b/README.rst index e1ea351f84..de133f674d 100644 --- a/README.rst +++ b/README.rst @@ -652,6 +652,7 @@ Using a reverse proxy with Synapse It is recommended to put a reverse proxy such as `nginx `_, `Apache `_ or +`Caddy `_ or `HAProxy `_ in front of Synapse. One advantage of doing so is that it means that you can expose the default https port (443) to Matrix clients without needing to run Synapse with root privileges. @@ -682,6 +683,11 @@ so an example nginx configuration might look like:: } } +an example caddy configuration may look like:: + proxy /_matrix http://localhost:8008 { + transparent + } + and an example apache configuration may look like:: -- cgit 1.4.1 From b85fe45f46abd7ff1c42851ab6dbb2f4ff33758f Mon Sep 17 00:00:00 2001 From: steamport Date: Fri, 19 Oct 2018 21:55:38 +0000 Subject: Add CL --- changelog.d/4072.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/4072.misc diff --git a/changelog.d/4072.misc b/changelog.d/4072.misc new file mode 100644 index 0000000000..9d7279fd2b --- /dev/null +++ b/changelog.d/4072.misc @@ -0,0 +1 @@ +The README now contains example for the Caddy web server. Contributed by steamp0rt. -- cgit 1.4.1 From 08760b0d9ac7d1ecf48b6aaef662cf346da7368d Mon Sep 17 00:00:00 2001 From: steamport Date: Fri, 19 Oct 2018 21:57:28 +0000 Subject: Fix formatting. --- README.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/README.rst b/README.rst index de133f674d..bb99f4fe2b 100644 --- a/README.rst +++ b/README.rst @@ -684,6 +684,7 @@ so an example nginx configuration might look like:: } an example caddy configuration may look like:: + proxy /_matrix http://localhost:8008 { transparent } -- cgit 1.4.1 From 9c2f99a3b74ebd84c14314344bf7ff785b8ff1d1 Mon Sep 17 00:00:00 2001 From: steamport Date: Fri, 19 Oct 2018 21:59:14 +0000 Subject: Fix --- README.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index bb99f4fe2b..b42a96a34c 100644 --- a/README.rst +++ b/README.rst @@ -684,9 +684,10 @@ so an example nginx configuration might look like:: } an example caddy configuration may look like:: - - proxy /_matrix http://localhost:8008 { - transparent + example.com { + proxy /_matrix http://localhost:8008 { + transparent + } } and an example apache configuration may look like:: -- cgit 1.4.1 From 3f357583ce17e0816acbc373bc2179293212ba4e Mon Sep 17 00:00:00 2001 From: steamport Date: Fri, 19 Oct 2018 21:59:39 +0000 Subject: I HATE RST --- README.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/README.rst b/README.rst index b42a96a34c..df9430f4cc 100644 --- a/README.rst +++ b/README.rst @@ -684,6 +684,7 @@ so an example nginx configuration might look like:: } an example caddy configuration may look like:: + example.com { proxy /_matrix http://localhost:8008 { transparent -- cgit 1.4.1 From 5c3d6ea9c72125ad152823e73f52667e845e6a61 Mon Sep 17 00:00:00 2001 From: steamport Date: Fri, 19 Oct 2018 22:00:27 +0000 Subject: Whoops! --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index df9430f4cc..209313ba37 100644 --- a/README.rst +++ b/README.rst @@ -685,7 +685,7 @@ so an example nginx configuration might look like:: an example caddy configuration may look like:: - example.com { + matrix.example.com { proxy /_matrix http://localhost:8008 { transparent } -- cgit 1.4.1 From 6340141300ea9370651b9dbcc2574fff96f45d48 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 22 Oct 2018 16:17:27 +0100 Subject: README.rst: fix minor grammar --- README.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 209313ba37..9f27c14c4a 100644 --- a/README.rst +++ b/README.rst @@ -651,7 +651,7 @@ Using a reverse proxy with Synapse It is recommended to put a reverse proxy such as `nginx `_, -`Apache `_ or +`Apache `_, `Caddy `_ or `HAProxy `_ in front of Synapse. One advantage of doing so is that it means that you can expose the default https port (443) to @@ -683,7 +683,7 @@ so an example nginx configuration might look like:: } } -an example caddy configuration may look like:: +an example Caddy configuration might look like:: matrix.example.com { proxy /_matrix http://localhost:8008 { @@ -691,7 +691,7 @@ an example caddy configuration may look like:: } } -and an example apache configuration may look like:: +and an example Apache configuration might look like:: SSLEngine on -- cgit 1.4.1 From 6105c6101fa0f4560daf7d23dedcfd217530b02a Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 23 Oct 2018 15:24:58 +0100 Subject: fix race condiftion in calling initialise_reserved_users --- changelog.d/4081.bugfix | 2 ++ synapse/app/homeserver.py | 8 ------ synapse/storage/monthly_active_users.py | 46 +++++++++++++++++++++--------- synapse/storage/registration.py | 16 ++++++++--- tests/storage/test_monthly_active_users.py | 10 +++++-- 5 files changed, 55 insertions(+), 27 deletions(-) create mode 100644 changelog.d/4081.bugfix diff --git a/changelog.d/4081.bugfix b/changelog.d/4081.bugfix new file mode 100644 index 0000000000..f275acb61c --- /dev/null +++ b/changelog.d/4081.bugfix @@ -0,0 +1,2 @@ +Fix race condition in populating reserved users + diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 0b85b377e3..593e1e75db 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -553,14 +553,6 @@ def run(hs): generate_monthly_active_users, ) - # XXX is this really supposed to be a background process? it looks - # like it needs to complete before some of the other stuff runs. - run_as_background_process( - "initialise_reserved_users", - hs.get_datastore().initialise_reserved_users, - hs.config.mau_limits_reserved_threepids, - ) - start_generate_monthly_active_users() if hs.config.limit_usage_by_mau: clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 0fe8c8e24c..26e577814a 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -33,19 +33,28 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs self.reserved_users = () + self.initialise_reserved_users( + dbconn.cursor(), hs.config.mau_limits_reserved_threepids + ) - @defer.inlineCallbacks - def initialise_reserved_users(self, threepids): - store = self.hs.get_datastore() + def initialise_reserved_users(self, txn, threepids): + """ + Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Arguments: + threepids []: List of threepid dicts to reserve + """ reserved_user_list = [] # Do not add more reserved users than the total allowable number for tp in threepids[:self.hs.config.max_mau_value]: - user_id = yield store.get_user_id_by_threepid( + user_id = self.get_user_id_by_threepid_txn( + txn, tp["medium"], tp["address"] ) if user_id: - yield self.upsert_monthly_active_user(user_id) + self.upsert_monthly_active_user_txn(txn, user_id) reserved_user_list.append(user_id) else: logger.warning( @@ -55,8 +64,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): @defer.inlineCallbacks def reap_monthly_active_users(self): - """ - Cleans out monthly active user table to ensure that no stale + """Cleans out monthly active user table to ensure that no stale entries exist. Returns: @@ -165,19 +173,33 @@ class MonthlyActiveUsersStore(SQLBaseStore): @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): + """Updates or inserts monthly active user member + Arguments: + user_id (str): user to add/update + """ + is_insert = yield self.runInteraction( + "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, + user_id + ) + if is_insert: + self.user_last_seen_monthly_active.invalidate((user_id,)) + self.get_monthly_active_count.invalidate(()) + + def upsert_monthly_active_user_txn(self, txn, user_id): """ Updates or inserts monthly active user member Arguments: + txn (cursor): user_id (str): user to add/update - Deferred[bool]: True if a new entry was created, False if an + bool: True if a new entry was created, False if an existing one was updated. """ # Am consciously deciding to lock the table on the basis that is ought # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = yield self._simple_upsert( - desc="upsert_monthly_active_user", + is_insert = self._simple_upsert_txn( + txn, table="monthly_active_users", keyvalues={ "user_id": user_id, @@ -186,9 +208,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): "timestamp": int(self._clock.time_msec()), }, ) - if is_insert: - self.user_last_seen_monthly_active.invalidate((user_id,)) - self.get_monthly_active_count.invalidate(()) + return is_insert @cached(num_args=1) def user_last_seen_monthly_active(self, user_id): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 26b429e307..01931f29cc 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -474,17 +474,25 @@ class RegistrationStore(RegistrationWorkerStore, @defer.inlineCallbacks def get_user_id_by_threepid(self, medium, address): - ret = yield self._simple_select_one( + user_id = yield self.runInteraction( + "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, + medium, address + ) + defer.returnValue(user_id) + + def get_user_id_by_threepid_txn(self, txn, medium, address): + ret = self._simple_select_one_txn( + txn, "user_threepids", { "medium": medium, "address": address }, - ['user_id'], True, 'get_user_id_by_threepid' + ['user_id'], True ) if ret: - defer.returnValue(ret['user_id']) - defer.returnValue(None) + return ret['user_id'] + return None def user_delete_threepid(self, user_id, medium, address): return self._simple_delete( diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 686f12a0dc..0c17745ae7 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -52,7 +52,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): now = int(self.hs.get_clock().time_msec()) self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - self.store.initialise_reserved_users(threepids) + + self.store.runInteraction( + "initialise", self.store.initialise_reserved_users, threepids + ) self.pump() active_count = self.store.get_monthly_active_count() @@ -199,7 +202,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): {'medium': 'email', 'address': user2_email}, ] self.hs.config.mau_limits_reserved_threepids = threepids - self.store.initialise_reserved_users(threepids) + self.store.runInteraction( + "initialise", self.store.initialise_reserved_users, threepids + ) + self.pump() count = self.store.get_registered_reserved_users_count() self.assertEquals(self.get_success(count), 0) -- cgit 1.4.1 From 329d18b39cbc1bc9eba8ae9de1fcf734d0cf1a78 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 23 Oct 2018 15:27:20 +0100 Subject: remove white space --- changelog.d/4081.bugfix | 1 - synapse/storage/monthly_active_users.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/changelog.d/4081.bugfix b/changelog.d/4081.bugfix index f275acb61c..13dad58842 100644 --- a/changelog.d/4081.bugfix +++ b/changelog.d/4081.bugfix @@ -1,2 +1 @@ Fix race condition in populating reserved users - diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 26e577814a..cf15f8c5ba 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -38,8 +38,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): ) def initialise_reserved_users(self, txn, threepids): - """ - Ensures that reserved threepids are accounted for in the MAU table, should + """Ensures that reserved threepids are accounted for in the MAU table, should be called on start up. Arguments: -- cgit 1.4.1 From a67d8ace9bc8b5f5ab953fdcfd6ade077337782d Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 23 Oct 2018 17:44:39 +0100 Subject: remove errant exception and style --- changelog.d/3975.feature | 2 +- synapse/config/registration.py | 2 +- synapse/handlers/register.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/changelog.d/3975.feature b/changelog.d/3975.feature index 79c2711fb4..1e33f1b3b8 100644 --- a/changelog.d/3975.feature +++ b/changelog.d/3975.feature @@ -1 +1 @@ -Servers with auto-join rooms, should automatically create those rooms when first user registers +Servers with auto-join rooms, will now automatically create those rooms when the first user registers diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 4b9bf6f2d1..5df321b287 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -18,7 +18,7 @@ from distutils.util import strtobool from synapse.types import RoomAlias from synapse.util.stringutils import random_string_with_symbols -from ._base import Config, ConfigError +from synapse.config._base import Config, ConfigError class RegistrationConfig(Config): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 1b5873c8d7..9615dd552f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -231,15 +231,15 @@ class RegistrationHandler(BaseHandler): for r in self.hs.config.auto_join_rooms: try: if should_auto_create_rooms: - if self.hs.hostname != RoomAlias.from_string(r).domain: - logger.warn( + room_alias = RoomAlias.from_string(r) + if self.hs.hostname != room_alias.domain: + logger.warning( 'Cannot create room alias %s, ' - 'it does not match server domain' % (r,) + 'it does not match server domain', (r,) ) - raise SynapseError() else: # create room expects the localpart of the room alias - room_alias_localpart = RoomAlias.from_string(r).localpart + room_alias_localpart = room_alias.localpart yield self.room_creation_handler.create_room( fake_requester, config={ -- cgit 1.4.1 From 47a9ba435d9a4b9e311d9d9a3c02be105942f357 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 19 Oct 2018 10:26:50 +0100 Subject: Use match rather than search --- synapse/config/room_directory.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 41ef3217e8..2ca010afdb 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -92,10 +92,11 @@ class _AliasRule(object): boolean """ - if not self._user_id_regex.search(user_id): + # Note: The regexes are anchored at both ends + if not self._user_id_regex.match(user_id): return False - if not self._alias_regex.search(alias): + if not self._alias_regex.match(alias): return False return True -- cgit 1.4.1 From 9f72c209eed4bec78133ca9821961eb9a3cf0dad Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 24 Oct 2018 14:37:36 +0100 Subject: Update changelog.d/3975.feature Co-Authored-By: neilisfragile --- changelog.d/3975.feature | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/3975.feature b/changelog.d/3975.feature index 1e33f1b3b8..162f30a532 100644 --- a/changelog.d/3975.feature +++ b/changelog.d/3975.feature @@ -1 +1 @@ -Servers with auto-join rooms, will now automatically create those rooms when the first user registers +Servers with auto-join rooms will now automatically create those rooms when the first user registers -- cgit 1.4.1 From 94a49e0636313507fab2e43b99f969385f535ea2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 24 Oct 2018 14:39:23 +0100 Subject: fix tuple Co-Authored-By: neilisfragile --- synapse/handlers/register.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 9615dd552f..217928add1 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -235,7 +235,8 @@ class RegistrationHandler(BaseHandler): if self.hs.hostname != room_alias.domain: logger.warning( 'Cannot create room alias %s, ' - 'it does not match server domain', (r,) + 'it does not match server domain', + r, ) else: # create room expects the localpart of the room alias -- cgit 1.4.1 From ab96ee29c96d10638d6ae69a80521aa8c4f19113 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 24 Oct 2018 13:41:31 +0100 Subject: reduce git clone depth --- .travis.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.travis.yml b/.travis.yml index 197dec2bc9..5b6fc47cee 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,6 +4,10 @@ language: python # tell travis to cache ~/.cache/pip cache: pip +# don't clone the whole repo history, one commit will do +git: + depth: 1 + # only build branches we care about (PRs are built seperately) branches: only: -- cgit 1.4.1 From 480d98c91f32da5127b695047b912b889d0b9dc2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 24 Oct 2018 14:49:25 +0100 Subject: Disable newsfragment checks on branch builds --- .travis.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 5b6fc47cee..107b3ab70b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -53,7 +53,9 @@ matrix: - python: 3.6 env: TOX_ENV=check_isort - - python: 3.6 + - # we only need to check for the newsfragment if it's a PR build + if: type = pull_request + python: 3.6 env: TOX_ENV=check-newsfragment install: -- cgit 1.4.1 From 83d9ca71225239e2f4ba27dc79907e93d8573d1b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 24 Oct 2018 13:48:03 +0100 Subject: only fetch develop for check-newsfragments --- .travis.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 107b3ab70b..56e10dbdc6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,10 +15,6 @@ branches: - develop - /^release-v/ -before_script: - - git remote set-branches --add origin develop - - git fetch origin develop - matrix: fast_finish: true include: @@ -57,6 +53,10 @@ matrix: if: type = pull_request python: 3.6 env: TOX_ENV=check-newsfragment + script: + - git remote set-branches --add origin develop + - git fetch origin develop + - tox -e $TOX_ENV install: - pip install tox -- cgit 1.4.1 From 9532caf6ef3cdc6dba80b11a920410b50d3490dd Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 24 Oct 2018 16:08:25 +0100 Subject: remove trailing whiter space --- synapse/handlers/register.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 217928add1..e9d7b25a36 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -235,7 +235,7 @@ class RegistrationHandler(BaseHandler): if self.hs.hostname != room_alias.domain: logger.warning( 'Cannot create room alias %s, ' - 'it does not match server domain', + 'it does not match server domain', r, ) else: -- cgit 1.4.1 From 9ec218658650558a704e939e56f87d1df1a84423 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 24 Oct 2018 16:09:21 +0100 Subject: isort --- synapse/config/registration.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 5df321b287..7480ed5145 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -15,11 +15,10 @@ from distutils.util import strtobool +from synapse.config._base import Config, ConfigError from synapse.types import RoomAlias from synapse.util.stringutils import random_string_with_symbols -from synapse.config._base import Config, ConfigError - class RegistrationConfig(Config): -- cgit 1.4.1 From 663d9db8e7892b79a927de70ec9add64312827d5 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 24 Oct 2018 17:17:30 +0100 Subject: commit transaction before closing --- synapse/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/synapse/server.py b/synapse/server.py index 3e9d3d8256..cf6b872cbd 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -207,6 +207,7 @@ class HomeServer(object): logger.info("Setting up.") with self.get_db_conn() as conn: self.datastore = self.DATASTORE_CLASS(conn, self) + conn.commit() logger.info("Finished setting up.") def get_reactor(self): -- cgit 1.4.1 From ea69a84bbb2fc9c1da6db0384a98f577f3bc95a7 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 24 Oct 2018 17:18:08 +0100 Subject: fix style inconsistencies --- changelog.d/4081.bugfix | 3 ++- synapse/storage/monthly_active_users.py | 43 +++++++++++++++++++----------- synapse/storage/registration.py | 19 +++++++++++++ tests/storage/test_monthly_active_users.py | 4 +-- 4 files changed, 51 insertions(+), 18 deletions(-) diff --git a/changelog.d/4081.bugfix b/changelog.d/4081.bugfix index 13dad58842..cfe4b3e9d9 100644 --- a/changelog.d/4081.bugfix +++ b/changelog.d/4081.bugfix @@ -1 +1,2 @@ -Fix race condition in populating reserved users +Fix race condition where config defined reserved users were not being added to +the monthly active user list prior to the homeserver reactor firing up diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index cf15f8c5ba..9a5c3b7eda 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -33,21 +33,23 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs self.reserved_users = () - self.initialise_reserved_users( - dbconn.cursor(), hs.config.mau_limits_reserved_threepids + # Do not add more reserved users than the total allowable number + self._initialise_reserved_users( + dbconn.cursor(), + hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value], ) - def initialise_reserved_users(self, txn, threepids): + def _initialise_reserved_users(self, txn, threepids): """Ensures that reserved threepids are accounted for in the MAU table, should be called on start up. - Arguments: - threepids []: List of threepid dicts to reserve + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve """ reserved_user_list = [] - # Do not add more reserved users than the total allowable number - for tp in threepids[:self.hs.config.max_mau_value]: + for tp in threepids: user_id = self.get_user_id_by_threepid_txn( txn, tp["medium"], tp["address"] @@ -172,26 +174,36 @@ class MonthlyActiveUsersStore(SQLBaseStore): @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): - """Updates or inserts monthly active user member - Arguments: + """Updates or inserts the user into the monthly active user table, which + is used to track the current MAU usage of the server + + Args: user_id (str): user to add/update """ is_insert = yield self.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) + # Considered pushing cache invalidation down into txn method, but + # did not because txn is not a LoggingTransaction. This means I could not + # call txn.call_after(). Therefore cache is altered in background thread + # and calls from elsewhere to user_last_seen_monthly_active and + # get_monthly_active_count fail with ValueError in + # synapse/util/caches/descriptors.py#check_thread if is_insert: self.user_last_seen_monthly_active.invalidate((user_id,)) self.get_monthly_active_count.invalidate(()) def upsert_monthly_active_user_txn(self, txn, user_id): - """ - Updates or inserts monthly active user member - Arguments: - txn (cursor): - user_id (str): user to add/update + """Updates or inserts monthly active user member + + Args: + txn (cursor): + user_id (str): user to add/update + + Returns: bool: True if a new entry was created, False if an - existing one was updated. + existing one was updated. """ # Am consciously deciding to lock the table on the basis that is ought # never be a big table and alternative approaches (batching multiple @@ -207,6 +219,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): "timestamp": int(self._clock.time_msec()), }, ) + return is_insert @cached(num_args=1) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 0f970850e8..80d76bf9d7 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -474,6 +474,15 @@ class RegistrationStore(RegistrationWorkerStore, @defer.inlineCallbacks def get_user_id_by_threepid(self, medium, address): + """Returns user id from threepid + + Args: + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + Deferred[str|None]: user id or None if no user id/threepid mapping exists + """ user_id = yield self.runInteraction( "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address @@ -481,6 +490,16 @@ class RegistrationStore(RegistrationWorkerStore, defer.returnValue(user_id) def get_user_id_by_threepid_txn(self, txn, medium, address): + """Returns user id from threepid + + Args: + txn (cursor): + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + str|None: user id or None if no user id/threepid mapping exists + """ ret = self._simple_select_one_txn( txn, "user_threepids", diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 0c17745ae7..832e379a83 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -54,7 +54,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): self.store.user_add_threepid(user2, "email", user2_email, now, now) self.store.runInteraction( - "initialise", self.store.initialise_reserved_users, threepids + "initialise", self.store._initialise_reserved_users, threepids ) self.pump() @@ -203,7 +203,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): ] self.hs.config.mau_limits_reserved_threepids = threepids self.store.runInteraction( - "initialise", self.store.initialise_reserved_users, threepids + "initialise", self.store._initialise_reserved_users, threepids ) self.pump() -- cgit 1.4.1 From 77d3b5772fc53fdc4461ebb3e6bd2c6c8f7e78bf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 24 Oct 2018 21:59:26 +0100 Subject: disable coverage checking I don't think we ever use this, and it slows things down. If we want to use it, we should just do so on a couple of builds rather than all of them. --- tox.ini | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tox.ini b/tox.ini index 04d2f721bf..9de5a5704a 100644 --- a/tox.ini +++ b/tox.ini @@ -3,7 +3,6 @@ envlist = packaging, py27, py36, pep8, check_isort [base] deps = - coverage Twisted>=17.1 mock python-subunit @@ -26,9 +25,7 @@ passenv = * commands = /usr/bin/find "{toxinidir}" -name '*.pyc' -delete - coverage run {env:COVERAGE_OPTS:} --source="{toxinidir}/synapse" \ - "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} - {env:DUMP_COVERAGE_COMMAND:coverage report -m} + "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} [testenv:py27] -- cgit 1.4.1 From fc33e813231a9a78012080193a4a2a96951515da Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 25 Oct 2018 00:44:55 +0100 Subject: Combine the pep8 and check_isort builds into one there's really no point spinning up two separate jobs for these. --- .travis.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 56e10dbdc6..64079b56f6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -22,7 +22,7 @@ matrix: env: TOX_ENV=packaging - python: 3.6 - env: TOX_ENV=pep8 + env: TOX_ENV="pep8, check_isort" - python: 2.7 env: TOX_ENV=py27 @@ -46,9 +46,6 @@ matrix: services: - postgresql - - python: 3.6 - env: TOX_ENV=check_isort - - # we only need to check for the newsfragment if it's a PR build if: type = pull_request python: 3.6 -- cgit 1.4.1 From 46f98a6a297ff014e7e88fb056e59755fd18cf58 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 25 Oct 2018 01:00:58 +0100 Subject: Only cache the wheels --- .travis.yml | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 64079b56f6..ac34c36725 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,16 @@ sudo: false language: python -# tell travis to cache ~/.cache/pip -cache: pip +cache: + directories: + # we only bother to cache the wheels; parts of the http cache get + # invalidated every build (because they get served with a max-age of 600 + # seconds), which means that we end up re-uploading the whole cache for + # every build, which is time-consuming In any case, it's not obvious that + # downloading the cache from S3 would be much faster than downloading the + # originals from pypi. + # + - $HOME/.cache/pip/wheels # don't clone the whole repo history, one commit will do git: -- cgit 1.4.1 From edd2d828095c5e9352d2755af19c0be5e4dc9e0d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 25 Oct 2018 01:06:39 +0100 Subject: oops, run the check_isort build --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index ac34c36725..fd41841c77 100644 --- a/.travis.yml +++ b/.travis.yml @@ -30,7 +30,7 @@ matrix: env: TOX_ENV=packaging - python: 3.6 - env: TOX_ENV="pep8, check_isort" + env: TOX_ENV="pep8,check_isort" - python: 2.7 env: TOX_ENV=py27 -- cgit 1.4.1 From f8fe98812be74e37432f11508c07488079bf949d Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 25 Oct 2018 14:58:59 +0100 Subject: improve comments --- synapse/storage/monthly_active_users.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 9a5c3b7eda..01963c879f 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -184,18 +184,18 @@ class MonthlyActiveUsersStore(SQLBaseStore): "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) - # Considered pushing cache invalidation down into txn method, but - # did not because txn is not a LoggingTransaction. This means I could not - # call txn.call_after(). Therefore cache is altered in background thread - # and calls from elsewhere to user_last_seen_monthly_active and - # get_monthly_active_count fail with ValueError in - # synapse/util/caches/descriptors.py#check_thread + if is_insert: self.user_last_seen_monthly_active.invalidate((user_id,)) self.get_monthly_active_count.invalidate(()) def upsert_monthly_active_user_txn(self, txn, user_id): """Updates or inserts monthly active user member + Note that, after calling this method, it will generally be necessary + to invalidate the caches on user_last_seen_monthly_active and + get_monthly_active_count. We can't do that here, because we are running + in a database thread rather than the main thread, and we can't call + txn.call_after because txn may not be a LoggingTransaction. Args: txn (cursor): -- cgit 1.4.1 From e5481b22aac800b016e05f50a50ced85a226b364 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Oct 2018 15:25:21 +0100 Subject: Use allow/deny --- synapse/config/room_directory.py | 12 ++++++------ tests/config/test_room_directory.py | 8 ++++---- tests/handlers/test_directory.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 2ca010afdb..9da13ab11b 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -35,13 +35,13 @@ class RoomDirectoryConfig(Config): # The format of this option is a list of rules that contain globs that # match against user_id and the new alias (fully qualified with server # name). The action in the first rule that matches is taken, which can - # currently either be "allowed" or "denied". + # currently either be "allow" or "deny". # # If no rules match the request is denied. alias_creation_rules: - user_id: "*" alias: "*" - action: allowed + action: allow """ def is_alias_creation_allowed(self, user_id, alias): @@ -56,7 +56,7 @@ class RoomDirectoryConfig(Config): """ for rule in self._alias_creation_rules: if rule.matches(user_id, alias): - return rule.action == "allowed" + return rule.action == "allow" return False @@ -67,12 +67,12 @@ class _AliasRule(object): user_id = rule["user_id"] alias = rule["alias"] - if action in ("allowed", "denied"): + if action in ("allow", "deny"): self.action = action else: raise ConfigError( - "alias_creation_rules rules can only have action of 'allowed'" - " or 'denied'" + "alias_creation_rules rules can only have action of 'allow'" + " or 'deny'" ) try: diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py index 75021a5f04..f37a17d618 100644 --- a/tests/config/test_room_directory.py +++ b/tests/config/test_room_directory.py @@ -26,16 +26,16 @@ class RoomDirectoryConfigTestCase(unittest.TestCase): alias_creation_rules: - user_id: "*bob*" alias: "*" - action: "denied" + action: "deny" - user_id: "*" alias: "#unofficial_*" - action: "allowed" + action: "allow" - user_id: "@foo*:example.com" alias: "*" - action: "allowed" + action: "allow" - user_id: "@gah:example.com" alias: "#goo:example.com" - action: "allowed" + action: "allow" """) rd_config = RoomDirectoryConfig() diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 4f299b74ba..8ae6556c0a 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -118,7 +118,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): { "user_id": "*", "alias": "#unofficial_*", - "action": "allowed", + "action": "allow", } ] -- cgit 1.4.1 From fcbd488e9a60a70a81cd7a94a6adebe6e06c525a Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 25 Oct 2018 16:13:43 +0100 Subject: add new line --- synapse/storage/monthly_active_users.py | 1 + 1 file changed, 1 insertion(+) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 01963c879f..cf4104dc2e 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -191,6 +191,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): def upsert_monthly_active_user_txn(self, txn, user_id): """Updates or inserts monthly active user member + Note that, after calling this method, it will generally be necessary to invalidate the caches on user_last_seen_monthly_active and get_monthly_active_count. We can't do that here, because we are running -- cgit 1.4.1 From cb53ce9d6429252d5ee012f5a476cc834251c27d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Oct 2018 17:49:55 +0100 Subject: Refactor state group lookup to reduce DB hits (#4011) Currently when fetching state groups from the data store we make two hits two the database: once for members and once for non-members (unless request is filtered to one or the other). This adds needless load to the datbase, so this PR refactors the lookup to make only a single database hit. --- changelog.d/4011.misc | 1 + synapse/handlers/initial_sync.py | 4 +- synapse/handlers/message.py | 20 +- synapse/handlers/pagination.py | 15 +- synapse/handlers/room.py | 22 +- synapse/handlers/sync.py | 97 ++--- synapse/rest/client/v1/room.py | 3 +- synapse/storage/events.py | 2 +- synapse/storage/state.py | 845 ++++++++++++++++++++++++--------------- synapse/visibility.py | 15 +- tests/storage/test_state.py | 175 +++++--- 11 files changed, 714 insertions(+), 485 deletions(-) create mode 100644 changelog.d/4011.misc diff --git a/changelog.d/4011.misc b/changelog.d/4011.misc new file mode 100644 index 0000000000..ad7768c4cd --- /dev/null +++ b/changelog.d/4011.misc @@ -0,0 +1 @@ +Reduce database load when fetching state groups diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index e009395207..563bb3cea3 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -156,7 +156,7 @@ class InitialSyncHandler(BaseHandler): room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( self.store.get_state_for_events, - [event.event_id], None, + [event.event_id], ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -301,7 +301,7 @@ class InitialSyncHandler(BaseHandler): def _room_initial_sync_parted(self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking): room_state = yield self.store.get_state_for_events( - [member_event_id], None + [member_event_id], ) room_state = room_state[member_event_id] diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6c4fcfb10a..969e588e73 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -35,6 +35,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder @@ -80,7 +81,7 @@ class MessageHandler(object): elif membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( - [membership_event_id], [key] + [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -88,7 +89,7 @@ class MessageHandler(object): @defer.inlineCallbacks def get_state_events( - self, user_id, room_id, types=None, filtered_types=None, + self, user_id, room_id, state_filter=StateFilter.all(), at_token=None, is_guest=False, ): """Retrieve all state events for a given room. If the user is @@ -100,13 +101,8 @@ class MessageHandler(object): Args: user_id(str): The user requesting state events. room_id(str): The room ID to get all state events from. - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. at_token(StreamToken|None): the stream token of the at which we are requesting the stats. If the user is not allowed to view the state as of that stream token, we raise a 403 SynapseError. If None, returns the current @@ -139,7 +135,7 @@ class MessageHandler(object): event = last_events[0] if visible_events: room_state = yield self.store.get_state_for_events( - [event.event_id], types, filtered_types=filtered_types, + [event.event_id], state_filter=state_filter, ) room_state = room_state[event.event_id] else: @@ -158,12 +154,12 @@ class MessageHandler(object): if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( - room_id, types, filtered_types=filtered_types, + room_id, state_filter=state_filter, ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - [membership_event_id], types, filtered_types=filtered_types, + [membership_event_id], state_filter=state_filter, ) room_state = room_state[membership_event_id] diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a155b6e938..43f81bd607 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -21,6 +21,7 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event +from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import ReadWriteLock from synapse.util.logcontext import run_in_background @@ -255,16 +256,14 @@ class PaginationHandler(object): if event_filter and event_filter.lazy_load_members(): # TODO: remove redundant members - types = [ - (EventTypes.Member, state_key) - for state_key in set( - event.sender # FIXME: we also care about invite targets etc. - for event in events - ) - ] + # FIXME: we also care about invite targets etc. + state_filter = StateFilter.from_types( + (EventTypes.Member, event.sender) + for event in events + ) state_ids = yield self.store.get_state_ids_for_event( - events[0].event_id, types=types, + events[0].event_id, state_filter=state_filter, ) if state_ids: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ab1571b27b..3ba92bdb4c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -33,6 +33,7 @@ from synapse.api.constants import ( RoomCreationPreset, ) from synapse.api.errors import AuthError, Codes, StoreError, SynapseError +from synapse.storage.state import StateFilter from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.util import stringutils from synapse.visibility import filter_events_for_client @@ -489,23 +490,24 @@ class RoomContextHandler(object): else: last_event_id = event_id - types = None - filtered_types = None if event_filter and event_filter.lazy_load_members(): - members = set(ev.sender for ev in itertools.chain( - results["events_before"], - (results["event"],), - results["events_after"], - )) - filtered_types = [EventTypes.Member] - types = [(EventTypes.Member, member) for member in members] + state_filter = StateFilter.from_lazy_load_member_list( + ev.sender + for ev in itertools.chain( + results["events_before"], + (results["event"],), + results["events_after"], + ) + ) + else: + state_filter = StateFilter.all() # XXX: why do we return the state as of the last event rather than the # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 state = yield self.store.get_state_for_events( - [last_event_id], types, filtered_types=filtered_types, + [last_event_id], state_filter=state_filter, ) results["state"] = list(state[last_event_id].values()) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 351892a94f..09739f2862 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -27,6 +27,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary +from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache @@ -469,25 +470,20 @@ class SyncHandler(object): )) @defer.inlineCallbacks - def get_state_after_event(self, event, types=None, filtered_types=None): + def get_state_after_event(self, event, state_filter=StateFilter.all()): """ Get the room state after the given event Args: event(synapse.events.EventBase): event of interest - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A Deferred map from ((type, state_key)->Event) """ state_ids = yield self.store.get_state_ids_for_event( - event.event_id, types, filtered_types=filtered_types, + event.event_id, state_filter=state_filter, ) if event.is_state(): state_ids = state_ids.copy() @@ -495,18 +491,14 @@ class SyncHandler(object): defer.returnValue(state_ids) @defer.inlineCallbacks - def get_state_at(self, room_id, stream_position, types=None, filtered_types=None): + def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()): """ Get the room state at a particular stream position Args: room_id(str): room for which to get state stream_position(StreamToken): point at which to get state - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A Deferred map from ((type, state_key)->Event) @@ -522,7 +514,7 @@ class SyncHandler(object): if last_events: last_event = last_events[-1] state = yield self.get_state_after_event( - last_event, types, filtered_types=filtered_types, + last_event, state_filter=state_filter, ) else: @@ -563,10 +555,11 @@ class SyncHandler(object): last_event = last_events[-1] state_ids = yield self.store.get_state_ids_for_event( - last_event.event_id, [ + last_event.event_id, + state_filter=StateFilter.from_types([ (EventTypes.Name, ''), (EventTypes.CanonicalAlias, ''), - ] + ]), ) # this is heavily cached, thus: fast. @@ -717,8 +710,7 @@ class SyncHandler(object): with Measure(self.clock, "compute_state_delta"): - types = None - filtered_types = None + members_to_fetch = None lazy_load_members = sync_config.filter_collection.lazy_load_members() include_redundant_members = ( @@ -729,16 +721,21 @@ class SyncHandler(object): # We only request state for the members needed to display the # timeline: - types = [ - (EventTypes.Member, state_key) - for state_key in set( - event.sender # FIXME: we also care about invite targets etc. - for event in batch.events - ) - ] + members_to_fetch = set( + event.sender # FIXME: we also care about invite targets etc. + for event in batch.events + ) + + if full_state: + # always make sure we LL ourselves so we know we're in the room + # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209 + # We only need apply this on full state syncs given we disabled + # LL for incr syncs in #3840. + members_to_fetch.add(sync_config.user.to_string()) - # only apply the filtering to room members - filtered_types = [EventTypes.Member] + state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch) + else: + state_filter = StateFilter.all() timeline_state = { (event.type, event.state_key): event.event_id @@ -746,28 +743,19 @@ class SyncHandler(object): } if full_state: - if lazy_load_members: - # always make sure we LL ourselves so we know we're in the room - # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209 - # We only need apply this on full state syncs given we disabled - # LL for incr syncs in #3840. - types.append((EventTypes.Member, sync_config.user.to_string())) - if batch: current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, types=types, - filtered_types=filtered_types, + batch.events[-1].event_id, state_filter=state_filter, ) state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=filtered_types, + batch.events[0].event_id, state_filter=state_filter, ) else: current_state_ids = yield self.get_state_at( - room_id, stream_position=now_token, types=types, - filtered_types=filtered_types, + room_id, stream_position=now_token, + state_filter=state_filter, ) state_ids = current_state_ids @@ -781,8 +769,7 @@ class SyncHandler(object): ) elif batch.limited: state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=filtered_types, + batch.events[0].event_id, state_filter=state_filter, ) # for now, we disable LL for gappy syncs - see @@ -797,17 +784,15 @@ class SyncHandler(object): # members to just be ones which were timeline senders, which then ensures # all of the rest get included in the state block (if we need to know # about them). - types = None - filtered_types = None + state_filter = StateFilter.all() state_at_previous_sync = yield self.get_state_at( - room_id, stream_position=since_token, types=types, - filtered_types=filtered_types, + room_id, stream_position=since_token, + state_filter=state_filter, ) current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, types=types, - filtered_types=filtered_types, + batch.events[-1].event_id, state_filter=state_filter, ) state_ids = _calculate_state( @@ -821,7 +806,7 @@ class SyncHandler(object): else: state_ids = {} if lazy_load_members: - if types and batch.events: + if members_to_fetch and batch.events: # We're returning an incremental sync, with no # "gap" since the previous sync, so normally there would be # no state to return. @@ -831,8 +816,12 @@ class SyncHandler(object): # timeline here, and then dedupe any redundant ones below. state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=None, # we only want members! + batch.events[0].event_id, + # we only want members! + state_filter=StateFilter.from_types( + (EventTypes.Member, member) + for member in members_to_fetch + ), ) if lazy_load_members and not include_redundant_members: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 663934efd0..fcfe7857f6 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -33,6 +33,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID @@ -409,7 +410,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): room_id=room_id, user_id=requester.user.to_string(), at_token=at_token, - types=[(EventTypes.Member, None)], + state_filter=StateFilter.from_types([(EventTypes.Member, None)]), ) chunk = [] diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c780f55277..8881b009df 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -2089,7 +2089,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore for sg in remaining_state_groups: logger.info("[purge] de-delta-ing remaining state group %s", sg) curr_state = self._get_state_groups_from_groups_txn( - txn, [sg], types=None + txn, [sg], ) curr_state = curr_state[sg] diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 3f4cbd61c4..ef65929bb2 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -19,6 +19,8 @@ from collections import namedtuple from six import iteritems, itervalues from six.moves import range +import attr + from twisted.internet import defer from synapse.api.constants import EventTypes @@ -48,6 +50,318 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 +@attr.s(slots=True) +class StateFilter(object): + """A filter used when querying for state. + + Attributes: + types (dict[str, set[str]|None]): Map from type to set of state keys (or + None). This specifies which state_keys for the given type to fetch + from the DB. If None then all events with that type are fetched. If + the set is empty then no events with that type are fetched. + include_others (bool): Whether to fetch events with types that do not + appear in `types`. + """ + + types = attr.ib() + include_others = attr.ib(default=False) + + def __attrs_post_init__(self): + # If `include_others` is set we canonicalise the filter by removing + # wildcards from the types dictionary + if self.include_others: + self.types = { + k: v for k, v in iteritems(self.types) + if v is not None + } + + @staticmethod + def all(): + """Creates a filter that fetches everything. + + Returns: + StateFilter + """ + return StateFilter(types={}, include_others=True) + + @staticmethod + def none(): + """Creates a filter that fetches nothing. + + Returns: + StateFilter + """ + return StateFilter(types={}, include_others=False) + + @staticmethod + def from_types(types): + """Creates a filter that only fetches the given types + + Args: + types (Iterable[tuple[str, str|None]]): A list of type and state + keys to fetch. A state_key of None fetches everything for + that type + + Returns: + StateFilter + """ + type_dict = {} + for typ, s in types: + if typ in type_dict: + if type_dict[typ] is None: + continue + + if s is None: + type_dict[typ] = None + continue + + type_dict.setdefault(typ, set()).add(s) + + return StateFilter(types=type_dict) + + @staticmethod + def from_lazy_load_member_list(members): + """Creates a filter that returns all non-member events, plus the member + events for the given users + + Args: + members (iterable[str]): Set of user IDs + + Returns: + StateFilter + """ + return StateFilter( + types={EventTypes.Member: set(members)}, + include_others=True, + ) + + def return_expanded(self): + """Creates a new StateFilter where type wild cards have been removed + (except for memberships). The returned filter is a superset of the + current one, i.e. anything that passes the current filter will pass + the returned filter. + + This helps the caching as the DictionaryCache knows if it has *all* the + state, but does not know if it has all of the keys of a particular type, + which makes wildcard lookups expensive unless we have a complete cache. + Hence, if we are doing a wildcard lookup, populate the cache fully so + that we can do an efficient lookup next time. + + Note that since we have two caches, one for membership events and one for + other events, we can be a bit more clever than simply returning + `StateFilter.all()` if `has_wildcards()` is True. + + We return a StateFilter where: + 1. the list of membership events to return is the same + 2. if there is a wildcard that matches non-member events we + return all non-member events + + Returns: + StateFilter + """ + + if self.is_full(): + # If we're going to return everything then there's nothing to do + return self + + if not self.has_wildcards(): + # If there are no wild cards, there's nothing to do + return self + + if EventTypes.Member in self.types: + get_all_members = self.types[EventTypes.Member] is None + else: + get_all_members = self.include_others + + has_non_member_wildcard = self.include_others or any( + state_keys is None + for t, state_keys in iteritems(self.types) + if t != EventTypes.Member + ) + + if not has_non_member_wildcard: + # If there are no non-member wild cards we can just return ourselves + return self + + if get_all_members: + # We want to return everything. + return StateFilter.all() + else: + # We want to return all non-members, but only particular + # memberships + return StateFilter( + types={EventTypes.Member: self.types[EventTypes.Member]}, + include_others=True, + ) + + def make_sql_filter_clause(self): + """Converts the filter to an SQL clause. + + For example: + + f = StateFilter.from_types([("m.room.create", "")]) + clause, args = f.make_sql_filter_clause() + clause == "(type = ? AND state_key = ?)" + args == ['m.room.create', ''] + + + Returns: + tuple[str, list]: The SQL string (may be empty) and arguments. An + empty SQL string is returned when the filter matches everything + (i.e. is "full"). + """ + + where_clause = "" + where_args = [] + + if self.is_full(): + return where_clause, where_args + + if not self.include_others and not self.types: + # i.e. this is an empty filter, so we need to return a clause that + # will match nothing + return "1 = 2", [] + + # First we build up a lost of clauses for each type/state_key combo + clauses = [] + for etype, state_keys in iteritems(self.types): + if state_keys is None: + clauses.append("(type = ?)") + where_args.append(etype) + continue + + for state_key in state_keys: + clauses.append("(type = ? AND state_key = ?)") + where_args.extend((etype, state_key)) + + # This will match anything that appears in `self.types` + where_clause = " OR ".join(clauses) + + # If we want to include stuff that's not in the types dict then we add + # a `OR type NOT IN (...)` clause to the end. + if self.include_others: + if where_clause: + where_clause += " OR " + + where_clause += "type NOT IN (%s)" % ( + ",".join(["?"] * len(self.types)), + ) + where_args.extend(self.types) + + return where_clause, where_args + + def max_entries_returned(self): + """Returns the maximum number of entries this filter will return if + known, otherwise returns None. + + For example a simple state filter asking for `("m.room.create", "")` + will return 1, whereas the default state filter will return None. + + This is used to bail out early if the right number of entries have been + fetched. + """ + if self.has_wildcards(): + return None + + return len(self.concrete_types()) + + def filter_state(self, state_dict): + """Returns the state filtered with by this StateFilter + + Args: + state (dict[tuple[str, str], Any]): The state map to filter + + Returns: + dict[tuple[str, str], Any]: The filtered state map + """ + if self.is_full(): + return dict(state_dict) + + filtered_state = {} + for k, v in iteritems(state_dict): + typ, state_key = k + if typ in self.types: + state_keys = self.types[typ] + if state_keys is None or state_key in state_keys: + filtered_state[k] = v + elif self.include_others: + filtered_state[k] = v + + return filtered_state + + def is_full(self): + """Whether this filter fetches everything or not + + Returns: + bool + """ + return self.include_others and not self.types + + def has_wildcards(self): + """Whether the filter includes wildcards or is attempting to fetch + specific state. + + Returns: + bool + """ + + return ( + self.include_others + or any( + state_keys is None + for state_keys in itervalues(self.types) + ) + ) + + def concrete_types(self): + """Returns a list of concrete type/state_keys (i.e. not None) that + will be fetched. This will be a complete list if `has_wildcards` + returns False, but otherwise will be a subset (or even empty). + + Returns: + list[tuple[str,str]] + """ + return [ + (t, s) + for t, state_keys in iteritems(self.types) + if state_keys is not None + for s in state_keys + ] + + def get_member_split(self): + """Return the filter split into two: one which assumes it's exclusively + matching against member state, and one which assumes it's matching + against non member state. + + This is useful due to the returned filters giving correct results for + `is_full()`, `has_wildcards()`, etc, when operating against maps that + either exclusively contain member events or only contain non-member + events. (Which is the case when dealing with the member vs non-member + state caches). + + Returns: + tuple[StateFilter, StateFilter]: The member and non member filters + """ + + if EventTypes.Member in self.types: + state_keys = self.types[EventTypes.Member] + if state_keys is None: + member_filter = StateFilter.all() + else: + member_filter = StateFilter({EventTypes.Member: state_keys}) + elif self.include_others: + member_filter = StateFilter.all() + else: + member_filter = StateFilter.none() + + non_member_filter = StateFilter( + types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member}, + include_others=self.include_others, + ) + + return member_filter, non_member_filter + + # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. @@ -152,61 +466,41 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # FIXME: how should this be cached? - def get_filtered_current_state_ids(self, room_id, types, filtered_types=None): + def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state + Args: room_id (str) - types (list[(Str, (Str|None))]): List of (type, state_key) tuples - which are used to filter the state fetched. `state_key` may be - None, which matches any `state_key` - filtered_types (list[Str]|None): List of types to apply the above filter to. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: - deferred: dict of (type, state_key) -> event + Deferred[dict[tuple[str, str], str]]: Map from type/state_key to + event ID. """ - include_other_types = False if filtered_types is None else True - def _get_filtered_current_state_ids_txn(txn): results = {} - sql = """SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? %s""" - # Turns out that postgres doesn't like doing a list of OR's and - # is about 1000x slower, so we just issue a query for each specific - # type seperately. - if types: - clause_to_args = [ - ( - "AND type = ? AND state_key = ?", - (etype, state_key) - ) if state_key is not None else ( - "AND type = ?", - (etype,) - ) - for etype, state_key in types - ] - - if include_other_types: - unique_types = set(filtered_types) - clause_to_args.append( - ( - "AND type <> ? " * len(unique_types), - list(unique_types) - ) - ) - else: - # If types is None we fetch all the state, and so just use an - # empty where clause with no extra args. - clause_to_args = [("", [])] - for where_clause, where_args in clause_to_args: - args = [room_id] - args.extend(where_args) - txn.execute(sql % (where_clause,), args) - for row in txn: - typ, state_key, event_id = row - key = (intern_string(typ), intern_string(state_key)) - results[key] = event_id + sql = """ + SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """ + + where_clause, where_args = state_filter.make_sql_filter_clause() + + if where_clause: + sql += " AND (%s)" % (where_clause,) + + args = [room_id] + args.extend(where_args) + txn.execute(sql, args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + return results return self.runInteraction( @@ -322,20 +616,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): }) @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, types, members=None): + def _get_state_groups_from_groups(self, groups, state_filter): """Returns the state groups for a given set of groups, filtering on types of state events. Args: groups(list[int]): list of state group IDs to query - types (Iterable[str, str|None]|None): list of 2-tuples of the form - (`type`, `state_key`), where a `state_key` of `None` matches all - state_keys for the `type`. If None, all types are returned. - members (bool|None): If not None, then, in addition to any filtering - implied by types, the results are also filtered to only include - member events (if True), or to exclude member events (if False) - - Returns: + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) @@ -346,19 +634,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, chunk, types, members, + self._get_state_groups_from_groups_txn, chunk, state_filter, ) results.update(res) defer.returnValue(results) def _get_state_groups_from_groups_txn( - self, txn, groups, types=None, members=None, + self, txn, groups, state_filter=StateFilter.all(), ): results = {group: {} for group in groups} - if types is not None: - types = list(set(types)) # deduplicate types list + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) if isinstance(self.database_engine, PostgresEngine): # Temporarily disable sequential scans in this transaction. This is @@ -374,79 +666,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # group for the given type, state_key. # This may return multiple rows per (type, state_key), but last_value # should be the same. - sql = (""" + sql = """ WITH RECURSIVE state(state_group) AS ( VALUES(?::bigint) UNION ALL SELECT prev_state_group FROM state_group_edges e, state s WHERE s.state_group = e.state_group ) - SELECT type, state_key, last_value(event_id) OVER ( + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( PARTITION BY type, state_key ORDER BY state_group ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS event_id FROM state_groups_state WHERE state_group IN ( SELECT state_group FROM state ) - %s - """) + """ - if members is True: - sql += " AND type = '%s'" % (EventTypes.Member,) - elif members is False: - sql += " AND type <> '%s'" % (EventTypes.Member,) - - # Turns out that postgres doesn't like doing a list of OR's and - # is about 1000x slower, so we just issue a query for each specific - # type seperately. - if types is not None: - clause_to_args = [ - ( - "AND type = ? AND state_key = ?", - (etype, state_key) - ) if state_key is not None else ( - "AND type = ?", - (etype,) - ) - for etype, state_key in types - ] - else: - # If types is None we fetch all the state, and so just use an - # empty where clause with no extra args. - clause_to_args = [("", [])] - - for where_clause, where_args in clause_to_args: - for group in groups: - args = [group] - args.extend(where_args) + for group in groups: + args = [group] + args.extend(where_args) - txn.execute(sql % (where_clause,), args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id else: - where_args = [] - where_clauses = [] - wildcard_types = False - if types is not None: - for typ in types: - if typ[1] is None: - where_clauses.append("(type = ?)") - where_args.append(typ[0]) - wildcard_types = True - else: - where_clauses.append("(type = ? AND state_key = ?)") - where_args.extend([typ[0], typ[1]]) - - where_clause = "AND (%s)" % (" OR ".join(where_clauses)) - else: - where_clause = "" - - if members is True: - where_clause += " AND type = '%s'" % EventTypes.Member - elif members is False: - where_clause += " AND type <> '%s'" % EventTypes.Member + max_entries_returned = state_filter.max_entries_returned() # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) @@ -460,12 +706,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # without the right indices (which we can't add until # after we finish deduping state, which requires this func) args = [next_group] - if types: - args.extend(where_args) + args.extend(where_args) txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? %s" % (where_clause,), + " WHERE state_group = ? " + where_clause, args ) results[group].update( @@ -481,9 +726,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # wildcards (i.e. Nones) in which case we have to do an exhaustive # search if ( - types is not None and - not wildcard_types and - len(results[group]) == len(types) + max_entries_returned is not None and + len(results[group]) == max_entries_returned ): break @@ -498,20 +742,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results @defer.inlineCallbacks - def get_state_for_events(self, event_ids, types, filtered_types=None): + def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): """Given a list of event_ids and type tuples, return a list of state - dicts for each event. The state dicts will only have the type/state_keys - that are in the `types` list. + dicts for each event. Args: event_ids (list[string]) - types (list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: deferred: A dict of (event_id) -> (type, state_key) -> [state_events] @@ -521,7 +759,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) + group_to_state = yield self._get_state_for_groups(groups, state_filter) state_event_map = yield self.get_events( [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], @@ -540,20 +778,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None): + def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) Args: event_ids(list(str)): events whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from event_id -> (type, state_key) -> event_id @@ -563,7 +796,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) + group_to_state = yield self._get_state_for_groups(groups, state_filter) event_to_state = { event_id: group_to_state[group] @@ -573,45 +806,35 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_for_event(self, event_id, types=None, filtered_types=None): + def get_state_for_event(self, event_id, state_filter=StateFilter.all()): """ Get the state dict corresponding to a particular event Args: event_id(str): event whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], types, filtered_types) + state_map = yield self.get_state_for_events([event_id], state_filter) defer.returnValue(state_map[event_id]) @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, types=None, filtered_types=None): + def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): """ Get the state dict corresponding to a particular event Args: event_id(str): event whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types) + state_map = yield self.get_state_ids_for_events([event_id], state_filter) defer.returnValue(state_map[event_id]) @cached(max_entries=50000) @@ -642,18 +865,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) - def _get_some_state_from_cache(self, cache, group, types, filtered_types=None): + def _get_state_for_group_using_cache(self, cache, group, state_filter): """Checks if group is in cache. See `_get_state_for_groups` Args: cache(DictionaryCache): the state group cache to use group(int): The state group to lookup - types(list[str, str|None]): List of 2-tuples of the form - (`type`, `state_key`), where a `state_key` of `None` matches all - state_keys for the `type`. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool indicating if we successfully retrieved all @@ -662,124 +881,102 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """ is_all, known_absent, state_dict_ids = cache.get(group) - type_to_key = {} + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all # tracks whether any of our requested types are missing from the cache missing_types = False - for typ, state_key in types: - key = (typ, state_key) - - if ( - state_key is None or - (filtered_types is not None and typ not in filtered_types) - ): - type_to_key[typ] = None - # we mark the type as missing from the cache because - # when the cache was populated it might have been done with a - # restricted set of state_keys, so the wildcard will not work - # and the cache may be incomplete. - missing_types = True - else: - if type_to_key.get(typ, object()) is not None: - type_to_key.setdefault(typ, set()).add(state_key) - + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): if key not in state_dict_ids and key not in known_absent: missing_types = True + break - sentinel = object() - - def include(typ, state_key): - valid_state_keys = type_to_key.get(typ, sentinel) - if valid_state_keys is sentinel: - return filtered_types is not None and typ not in filtered_types - if valid_state_keys is None: - return True - if state_key in valid_state_keys: - return True - return False - - got_all = is_all - if not got_all: - # the cache is incomplete. We may still have got all the results we need, if - # we don't have any wildcards in the match list. - if not missing_types and filtered_types is None: - got_all = True - - return { - k: v for k, v in iteritems(state_dict_ids) - if include(k[0], k[1]) - }, got_all - - def _get_all_state_from_cache(self, cache, group): - """Checks if group is in cache. See `_get_state_for_groups` - - Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool - indicating if we successfully retrieved all requests state from the - cache, if False we need to query the DB for the missing state. - - Args: - cache(DictionaryCache): the state group cache to use - group: The state group to lookup - """ - is_all, _, state_dict_ids = cache.get(group) - - return state_dict_ids, is_all + return state_filter.filter_state(state_dict_ids), not missing_types @defer.inlineCallbacks - def _get_state_for_groups(self, groups, types=None, filtered_types=None): + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): """Gets the state at each of a list of state groups, optionally filtering by type/state_key Args: groups (iterable[int]): list of state groups for which we want to get the state. - types (None|iterable[(str, None|str)]): - indicates the state type/keys required. If None, the whole - state is fetched and returned. - - Otherwise, each entry should be a `(type, state_key)` tuple to - include in the response. A `state_key` of None is a wildcard - meaning that we require all state with that type. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. - + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ - if types is not None: - non_member_types = [t for t in types if t[0] != EventTypes.Member] - if filtered_types is not None and EventTypes.Member not in filtered_types: - # we want all of the membership events - member_types = None - else: - member_types = [t for t in types if t[0] == EventTypes.Member] - - else: - non_member_types = None - member_types = None + member_filter, non_member_filter = state_filter.get_member_split() - non_member_state = yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, non_member_types, filtered_types, + # Now we look them up in the member and non-member caches + non_member_state, incomplete_groups_nm, = ( + yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, + state_filter=non_member_filter, + ) ) - # XXX: we could skip this entirely if member_types is [] - member_state = yield self._get_state_for_groups_using_cache( - # we set filtered_types=None as member_state only ever contain members. - groups, self._state_group_members_cache, member_types, None, + + member_state, incomplete_groups_m, = ( + yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, + state_filter=member_filter, + ) ) - state = non_member_state + state = dict(non_member_state) for group in groups: state[group].update(member_state[group]) + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + defer.returnValue(state) + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), + state_filter=db_state_filter, + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + defer.returnValue(state) - @defer.inlineCallbacks def _get_state_for_groups_using_cache( - self, groups, cache, types=None, filtered_types=None + self, groups, cache, state_filter, ): """Gets the state at each of a list of state groups, optionally filtering by type/state_key, querying from a specific cache. @@ -790,89 +987,85 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): cache (DictionaryCache): the cache of group ids to state dicts which we will pass through - either the normal state cache or the specific members state cache. - types (None|iterable[(str, None|str)]): - indicates the state type/keys required. If None, the whole - state is fetched and returned. - - Otherwise, each entry should be a `(type, state_key)` tuple to - include in the response. A `state_key` of None is a wildcard - meaning that we require all state with that type. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of + dict of state_group_id -> (dict of (type, state_key) -> event id) + of entries in the cache, and the state group ids either missing + from the cache or incomplete. """ - if types: - types = frozenset(types) results = {} - missing_groups = [] - if types is not None: - for group in set(groups): - state_dict_ids, got_all = self._get_some_state_from_cache( - cache, group, types, filtered_types - ) - results[group] = state_dict_ids + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids - if not got_all: - missing_groups.append(group) - else: - for group in set(groups): - state_dict_ids, got_all = self._get_all_state_from_cache( - cache, group - ) + if not got_all: + incomplete_groups.add(group) - results[group] = state_dict_ids + return results, incomplete_groups - if not got_all: - missing_groups.append(group) + def _insert_into_cache(self, group_to_state_dict, state_filter, + cache_seq_num_members, cache_seq_num_non_members): + """Inserts results from querying the database into the relevant cache. - if missing_groups: - # Okay, so we have some missing_types, let's fetch them. - cache_seq_num = cache.sequence + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ - # the DictionaryCache knows if it has *all* the state, but - # does not know if it has all of the keys of a particular type, - # which makes wildcard lookups expensive unless we have a complete - # cache. Hence, if we are doing a wildcard lookup, populate the - # cache fully so that we can do an efficient lookup next time. + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) - if filtered_types or (types and any(k is None for (t, k) in types)): - types_to_fetch = None - else: - types_to_fetch = types + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() - group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types_to_fetch, cache == self._state_group_members_cache, - ) + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() + + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict = results[group] - - # update the result, filtering by `types`. - if types: - for k, v in iteritems(group_state_dict): - (typ, _) = k - if ( - (k in types or (typ, None) in types) or - (filtered_types and typ not in filtered_types) - ): - state_dict[k] = v + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v else: - state_dict.update(group_state_dict) - - # update the cache with all the things we fetched from the - # database. - cache.update( - cache_seq_num, - key=group, - value=group_state_dict, - fetched_keys=types_to_fetch, - ) + state_dict_non_members[k] = v - defer.returnValue(results) + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) def store_state_group(self, event_id, room_id, prev_group, delta_ids, current_state_ids): @@ -1181,12 +1374,12 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): continue prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group], types=None + txn, [prev_group], ) prev_state = prev_state[prev_group] curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group], types=None + txn, [state_group], ) curr_state = curr_state[state_group] diff --git a/synapse/visibility.py b/synapse/visibility.py index 43f48196be..0281a7c919 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event +from synapse.storage.state import StateFilter from synapse.types import get_domain_from_id logger = logging.getLogger(__name__) @@ -72,7 +73,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, ) event_id_to_state = yield store.get_state_for_events( frozenset(e.event_id for e in events), - types=types, + state_filter=StateFilter.from_types(types), ) ignore_dict_content = yield store.get_global_account_data_by_type_for_user( @@ -273,8 +274,8 @@ def filter_events_for_server(store, server_name, events): # need to check membership (as we know the server is in the room). event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), + state_filter=StateFilter.from_types( + types=((EventTypes.RoomHistoryVisibility, ""),), ) ) @@ -314,9 +315,11 @@ def filter_events_for_server(store, server_name, events): # of the history vis and membership state at those events. event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, None), + state_filter=StateFilter.from_types( + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, None), + ), ) ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index b9c5b39d59..086a39d834 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -18,6 +18,7 @@ import logging from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.storage.state import StateFilter from synapse.types import RoomID, UserID import tests.unittest @@ -148,7 +149,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we get the full state as of the final event state = yield self.store.get_state_for_event( - e5.event_id, None, filtered_types=None + e5.event_id, ) self.assertIsNotNone(e4) @@ -166,33 +167,35 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can filter to the m.room.name event (with a '' state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Name, '')], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Name, '')]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Name, None)], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Member, None)], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) self.assertStateMapEqual( {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state ) - # check we can use filtered_types to grab a specific room member - # without filtering out the other event types + # check we can grab a specific room member without filtering out the + # other event types state = yield self.store.get_state_for_event( e5.event_id, - [(EventTypes.Member, self.u_alice.to_string())], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {self.u_alice.to_string()}}, + include_others=True, + ) ) self.assertStateMapEqual( @@ -204,10 +207,12 @@ class StateStoreTestCase(tests.unittest.TestCase): state, ) - # check that types=[], filtered_types=[EventTypes.Member] - # doesn't return all members + # check that we can grab everything except members state = yield self.store.get_state_for_event( - e5.event_id, [], filtered_types=[EventTypes.Member] + e5.event_id, state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertStateMapEqual( @@ -215,16 +220,21 @@ class StateStoreTestCase(tests.unittest.TestCase): ) ####################################################### - # _get_some_state_from_cache tests against a full cache + # _get_state_for_group_using_cache tests against a full cache ####################################################### room_id = self.room.to_string() group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) group = list(group_ids.keys())[0] - # test _get_some_state_from_cache correctly filters out members with types=[] - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] + # test _get_state_for_group_using_cache correctly filters out members + # with types=[] + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -236,22 +246,27 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({}, state_dict) - # test _get_some_state_from_cache correctly filters in members with wildcard types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with wildcard types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, None)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -263,11 +278,13 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, None)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -280,12 +297,15 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - # test _get_some_state_from_cache correctly filters in members with specific types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -297,23 +317,27 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) - # test _get_some_state_from_cache correctly filters in members with specific types - # and no filtered_types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=None, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, True) @@ -357,42 +381,54 @@ class StateStoreTestCase(tests.unittest.TestCase): ############################################ # test that things work with a partial cache - # test _get_some_state_from_cache correctly filters out members with types=[] + # test _get_state_for_group_using_cache correctly filters out members + # with types=[] room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({}, state_dict) - # test _get_some_state_from_cache correctly filters in members wildcard types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # wildcard types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, None)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, None)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -404,44 +440,53 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - # test _get_some_state_from_cache correctly filters in members with specific types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) - # test _get_some_state_from_cache correctly filters in members with specific types - # and no filtered_types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=None, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=None, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, True) -- cgit 1.4.1 From 379376e5e6ae242d9a69f0ec738758a649058a82 Mon Sep 17 00:00:00 2001 From: Cédric Laudrel Date: Thu, 25 Oct 2018 16:36:07 +0200 Subject: Make Docker image listening on ipv6 as well as ipv4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Cédric Laudrel --- changelog.d/4089.feature | 1 + docker/conf/homeserver.yaml | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/4089.feature diff --git a/changelog.d/4089.feature b/changelog.d/4089.feature new file mode 100644 index 0000000000..62c9d839bb --- /dev/null +++ b/changelog.d/4089.feature @@ -0,0 +1 @@ + Configure Docker image to listen on both ipv4 and ipv6. diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml index a38b929f50..1b0f655d26 100644 --- a/docker/conf/homeserver.yaml +++ b/docker/conf/homeserver.yaml @@ -21,7 +21,7 @@ listeners: {% if not SYNAPSE_NO_TLS %} - port: 8448 - bind_addresses: ['0.0.0.0'] + bind_addresses: ['::'] type: http tls: true x_forwarded: false @@ -34,7 +34,7 @@ listeners: - port: 8008 tls: false - bind_addresses: ['0.0.0.0'] + bind_addresses: ['::'] type: http x_forwarded: false -- cgit 1.4.1 From 0ee8d1b64149205de65e3f81ecc8bd38fc71ae95 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 29 Oct 2018 06:46:00 +0000 Subject: wip tests to filter out support user --- synapse/storage/user_directory.py | 6 ++---- tests/storage/test_user_directory.py | 39 ++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index 5699cea855..cd25e07719 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -339,7 +339,7 @@ class UserDirectoryStore(SQLBaseStore): rows = yield self._execute("get_all_local_users", None, sql) defer.returnValue([name for name, in rows]) - def add_users_who_share_room(self, room_id, share_private, user_id_tuples): + def add_users_who_share_room(self, room_id, share_private, user_id_tuples_x): """Insert entries into the users_who_share_rooms table. The first user should be a local user. @@ -350,9 +350,7 @@ class UserDirectoryStore(SQLBaseStore): """ def _add_users_who_share_room_txn(txn): support_user = self.hs.config.support_user_id - for ut in user_id_tuples: - if support_user in ut: - user_id_tuples.remove(ut) + user_id_tuples = filter(lambda x: support_user not in x, user_id_tuples_x) self._simple_insert_many_txn( txn, diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 0dde1ab2fe..12f64de691 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -75,3 +75,42 @@ class UserDirectoryStoreTestCase(unittest.TestCase): ) finally: self.hs.config.user_directory_search_all_users = False + + @defer.inlineCallbacks + def test_cannot_add_support_user_to_directory(self): + self.hs.config.user_directory_search_all_users = True + self.hs.config.support_user_id = "@support:test" + SUPPORT_USER = self.hs.config.support_user_id + yield self.store.add_profiles_to_user_dir( + "!room:id", + {SUPPORT_USER: ProfileInfo(None, "support")}, + ) + yield self.store.add_users_to_public_room("!room:id", [SUPPORT_USER]) + yield self.store.add_users_who_share_room( + "!room:id", False, ((ALICE, SUPPORT_USER),) + ) + + r = yield self.store.search_user_dir(ALICE, "support", 10) + self.assertFalse(r["limited"]) + self.assertEqual(0, len(r["results"])) + + # add_users_who_share_room + # add_users_to_public_room + # add_profiles_to_user_dir + # update_user_in_user_dir + # update_profile_in_user_dir + # update_user_in_public_user_list + + # yield self.store.add_profiles_to_user_dir( + # "!room:id", + # {SUPPORT_USER: ProfileInfo(None, "support")}, + # ) + # yield self.store.add_profiles_to_user_dir(SUPPORT_USER, + # + # + # + # yield self.store.add_users_to_public_room("!room:id", [SUPPORT_USER]) + # + # yield self.store.add_users_who_share_room( + # "!room:id", False, ((ALICE, SUPPORT_USER), (BOB, SUPPORT_USER)) + # ) -- cgit 1.4.1