From 2d3462714e48dca46dd54b17ca29188a17261e28 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 18 Aug 2015 14:22:02 +0100 Subject: Issue macaroons as opaque auth tokens This just replaces random bytes with macaroons. The macaroons are not inspected by the client or server. In particular, they claim to have an expiry time, but nothing verifies that they have not expired. Follow-up commits will actually enforce the expiration, and allow for token refresh. See https://bit.ly/matrix-auth for more information --- tests/utils.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'tests/utils.py') diff --git a/tests/utils.py b/tests/utils.py index eb035cf48f..80be70b74f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -44,6 +44,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.signing_key = [MockKey()] config.event_cache_size = 1 config.disable_registration = False + config.macaroon_secret_key = "not even a little secret" + config.server_name = "server.under.test" if "clock" not in kargs: kargs["clock"] = MockClock() -- cgit 1.5.1 From 13a6517d89c0619a938321640f331571eba0edc9 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 20 Aug 2015 16:01:29 +0100 Subject: s/by_token/by_access_token/g We're about to have two kinds of token, access and refresh --- synapse/api/auth.py | 6 +++--- synapse/storage/registration.py | 6 +++--- tests/api/test_auth.py | 16 ++++++++-------- tests/rest/client/v1/test_presence.py | 8 ++++---- tests/rest/client/v1/test_rooms.py | 28 ++++++++++++++-------------- tests/rest/client/v1/test_typing.py | 4 ++-- tests/rest/client/v1/utils.py | 2 +- tests/rest/client/v2_alpha/__init__.py | 4 ++-- tests/storage/test_registration.py | 4 ++-- tests/utils.py | 2 +- 10 files changed, 40 insertions(+), 40 deletions(-) (limited to 'tests/utils.py') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1e3b0fbfb7..3d9237ccc3 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -361,7 +361,7 @@ class Auth(object): except KeyError: pass # normal users won't have the user_id query parameter set. - user_info = yield self.get_user_by_token(access_token) + user_info = yield self.get_user_by_access_token(access_token) user = user_info["user"] device_id = user_info["device_id"] token_id = user_info["token_id"] @@ -390,7 +390,7 @@ class Auth(object): ) @defer.inlineCallbacks - def get_user_by_token(self, token): + def get_user_by_access_token(self, token): """ Get a registered user's ID. Args: @@ -401,7 +401,7 @@ class Auth(object): Raises: AuthError if no user by that token exists or the token is invalid. """ - ret = yield self.store.get_user_by_token(token) + ret = yield self.store.get_user_by_access_token(token) if not ret: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bf803f2c6e..0e404afb7c 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -132,10 +132,10 @@ class RegistrationStore(SQLBaseStore): user_id ) for r in rows: - self.get_user_by_token.invalidate((r,)) + self.get_user_by_access_token.invalidate((r,)) @cached() - def get_user_by_token(self, token): + def get_user_by_access_token(self, token): """Get a user from the given access token. Args: @@ -147,7 +147,7 @@ class RegistrationStore(SQLBaseStore): StoreError if no user was found. """ return self.runInteraction( - "get_user_by_token", + "get_user_by_access_token", self._query_for_auth, token ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 4f83db5e84..3343c635cc 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -44,7 +44,7 @@ class AuthTestCase(unittest.TestCase): "token_id": "ditto", "admin": False } - self.store.get_user_by_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -54,7 +54,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_user_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase): "token_id": "ditto", "admin": False } - self.store.get_user_by_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) request.requestHeaders.getRawHeaders = Mock(return_value=[""]) @@ -81,7 +81,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_valid_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -91,7 +91,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -102,7 +102,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_missing_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.requestHeaders.getRawHeaders = Mock(return_value=[""]) @@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -129,7 +129,7 @@ class AuthTestCase(unittest.TestCase): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 089a71568c..0b78a82a66 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -70,7 +70,7 @@ class PresenceStateTestCase(unittest.TestCase): return defer.succeed([]) self.datastore.get_presence_list = get_presence_list - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(myid), "admin": False, @@ -78,7 +78,7 @@ class PresenceStateTestCase(unittest.TestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token room_member_handler = hs.handlers.room_member_handler = Mock( spec=[ @@ -159,7 +159,7 @@ class PresenceListTestCase(unittest.TestCase): ) self.datastore.has_presence_state = has_presence_state - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(myid), "admin": False, @@ -173,7 +173,7 @@ class PresenceListTestCase(unittest.TestCase): ] ) - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token presence.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index c83348acf9..2e55cc08a1 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -54,14 +54,14 @@ class RoomPermissionsTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -441,14 +441,14 @@ class RoomsMemberListTestCase(RestTestCase): self.auth_user_id = self.user_id - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -521,14 +521,14 @@ class RoomsCreateTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, @@ -622,7 +622,7 @@ class RoomTopicTestCase(RestTestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -721,14 +721,14 @@ class RoomMemberStateTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -848,14 +848,14 @@ class RoomMessagesTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -945,14 +945,14 @@ class RoomInitialSyncTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 7d8b1c2683..dc8bbaaf0e 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -61,7 +61,7 @@ class RoomTypingTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, @@ -69,7 +69,7 @@ class RoomTypingTestCase(RestTestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 579441fb4a..c472d53043 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -37,7 +37,7 @@ class RestTestCase(unittest.TestCase): self.mock_resource = None self.auth_user_id = None - def mock_get_user_by_token(self, token=None): + def mock_get_user_by_access_token(self, token=None): return self.auth_user_id @defer.inlineCallbacks diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index de5a917e6a..15568b36cd 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -43,14 +43,14 @@ class V2AlphaRestTestCase(unittest.TestCase): resource_for_federation=self.mock_resource, ) - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.USER_ID), "admin": False, "device_id": None, "token_id": 1, } - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_auth().get_user_by_access_token = _get_user_by_access_token for r in self.TO_REGISTER: r.register_servlets(hs, self.mock_resource) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 2702291178..7a24cf898a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase): (yield self.store.get_user_by_id(self.user_id)) ) - result = yield self.store.get_user_by_token(self.tokens[0]) + result = yield self.store.get_user_by_access_token(self.tokens[0]) self.assertDictContainsSubset( { @@ -64,7 +64,7 @@ class RegistrationStoreTestCase(unittest.TestCase): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) - result = yield self.store.get_user_by_token(self.tokens[1]) + result = yield self.store.get_user_by_access_token(self.tokens[1]) self.assertDictContainsSubset( { diff --git a/tests/utils.py b/tests/utils.py index 80be70b74f..d0fba2252d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -277,7 +277,7 @@ class MemoryDataStore(object): raise StoreError(400, "User in use.") self.tokens_to_users[token] = user_id - def get_user_by_token(self, token): + def get_user_by_access_token(self, token): try: return { "name": self.tokens_to_users[token], -- cgit 1.5.1 From a0b181bd17cb7ec2a43ed2dbdeb1bb40f3f4373c Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 25 Aug 2015 16:23:06 +0100 Subject: Remove completely unused concepts from codebase Removes device_id and ClientInfo device_id is never actually written, and the matrix.org DB has no non-null entries for it. Right now, it's just cluttering up code. This doesn't remove the columns from the database, because that's fiddly. --- synapse/api/auth.py | 17 ++++++--------- synapse/handlers/admin.py | 1 + synapse/handlers/message.py | 9 +++----- synapse/rest/client/v1/admin.py | 2 +- synapse/rest/client/v1/directory.py | 4 ++-- synapse/rest/client/v1/events.py | 4 ++-- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v1/presence.py | 8 +++---- synapse/rest/client/v1/profile.py | 4 ++-- synapse/rest/client/v1/pusher.py | 4 ++-- synapse/rest/client/v1/room.py | 34 ++++++++++++++--------------- synapse/rest/client/v1/voip.py | 2 +- synapse/rest/client/v2_alpha/account.py | 4 ++-- synapse/rest/client/v2_alpha/filter.py | 4 ++-- synapse/rest/client/v2_alpha/keys.py | 6 ++--- synapse/rest/client/v2_alpha/receipts.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/media/v0/content_repository.py | 2 +- synapse/rest/media/v1/upload_resource.py | 2 +- synapse/storage/__init__.py | 7 +++--- synapse/storage/registration.py | 5 ++--- synapse/types.py | 4 ---- tests/api/test_auth.py | 8 +++---- tests/rest/client/v1/test_presence.py | 2 -- tests/rest/client/v1/test_rooms.py | 7 ------ tests/rest/client/v1/test_typing.py | 1 - tests/rest/client/v2_alpha/__init__.py | 1 - tests/storage/test_registration.py | 2 -- tests/utils.py | 3 +-- 29 files changed, 63 insertions(+), 90 deletions(-) (limited to 'tests/utils.py') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3d9237ccc3..1496db7dff 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError from synapse.util.logutils import log_function -from synapse.types import UserID, ClientInfo +from synapse.types import UserID import logging @@ -322,9 +322,9 @@ class Auth(object): Args: request - An HTTP request with an access_token query parameter. Returns: - tuple : of UserID and device string: - User ID object of the user making the request - ClientInfo object of the client instance the user is using + tuple of: + UserID (str) + Access token ID (str) Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -355,7 +355,7 @@ class Auth(object): request.authenticated_entity = user_id defer.returnValue( - (UserID.from_string(user_id), ClientInfo("", "")) + (UserID.from_string(user_id), "") ) return except KeyError: @@ -363,7 +363,6 @@ class Auth(object): user_info = yield self.get_user_by_access_token(access_token) user = user_info["user"] - device_id = user_info["device_id"] token_id = user_info["token_id"] ip_addr = self.hs.get_ip_from_request(request) @@ -375,14 +374,13 @@ class Auth(object): self.store.insert_client_ip( user=user, access_token=access_token, - device_id=user_info["device_id"], ip=ip_addr, user_agent=user_agent ) request.authenticated_entity = user.to_string() - defer.returnValue((user, ClientInfo(device_id, token_id))) + defer.returnValue((user, token_id,)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", @@ -396,7 +394,7 @@ class Auth(object): Args: token (str): The access token to get the user by. Returns: - dict : dict that includes the user, device_id, and whether the + dict : dict that includes the user and whether the user is a server admin. Raises: AuthError if no user by that token exists or the token is invalid. @@ -409,7 +407,6 @@ class Auth(object): ) user_info = { "admin": bool(ret.get("admin", False)), - "device_id": ret.get("device_id"), "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), } diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 1c9e7152c7..d852a18555 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -34,6 +34,7 @@ class AdminHandler(BaseHandler): d = {} for r in res: + # Note that device_id is always None device = d.setdefault(r["device_id"], {}) session = device.setdefault(r["access_token"], []) session.append({ diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index f12465fa2c..23b779ad7c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -183,7 +183,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def create_and_send_event(self, event_dict, ratelimit=True, - client=None, txn_id=None): + token_id=None, txn_id=None): """ Given a dict from a client, create and handle a new event. Creates an FrozenEvent object, filling out auth_events, prev_events, @@ -217,11 +217,8 @@ class MessageHandler(BaseHandler): builder.content ) - if client is not None: - if client.token_id is not None: - builder.internal_metadata.token_id = client.token_id - if client.device_id is not None: - builder.internal_metadata.device_id = client.device_id + if token_id is not None: + builder.internal_metadata.token_id = token_id if txn_id is not None: builder.internal_metadata.txn_id = txn_id diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 2ce754b028..504b63eab4 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(auth_user) if not is_admin and target_user != auth_user: diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 6758a888b3..4dcda57c1b 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet): try: # try to auth as a user - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) try: user_id = user.to_string() yield dir_handler.create_association( @@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet): # fallback to default user behaviour if they aren't an AS pass - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(user) if not is_admin: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 77b7c25a03..582148b659 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) try: handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) @@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, event_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.event_handler event = yield handler.get_event(auth_user, event_id) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 4a259bba64..4ea4da653c 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) with_feedback = "feedback" in request.args as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 78d4f2b128..a770efd841 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = yield self.handlers.presence_handler.get_state( @@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = {} @@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): @@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 1e77eb49cf..fdde88a60d 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: @@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index c83287c028..3aabc93b8b 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -65,7 +65,7 @@ class PusherRestServlet(ClientV1RestServlet): try: yield pusher_pool.add_pusher( user_name=user.to_string(), - access_token=client.token_id, + access_token=token_id, profile_tag=content['profile_tag'], kind=content['kind'], app_id=content['app_id'], diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index b4a70cba99..c9c27dd5a0 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) room_config = self.get_room_config(request) info = yield self.make_room(room_config, auth_user, None) @@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -143,7 +143,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -159,7 +159,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): msg_handler = self.handlers.message_handler yield msg_handler.create_and_send_event( - event_dict, client=client, txn_id=txn_id, + event_dict, token_id=token_id, txn_id=txn_id, ) defer.returnValue((200, {})) @@ -175,7 +175,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_type, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -186,7 +186,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): "room_id": room_id, "sender": user.to_string(), }, - client=client, + token_id=token_id, txn_id=txn_id, ) @@ -220,7 +220,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) # the identifier could be a room alias or a room id. Try one then the # other if it fails to parse, without swallowing other valid @@ -250,7 +250,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): "sender": user.to_string(), "state_key": user.to_string(), }, - client=client, + token_id=token_id, txn_id=txn_id, ) @@ -289,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.room_member_handler members = yield handler.get_room_members_as_pagination_chunk( room_id=room_id, @@ -317,7 +317,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request( request, default_limit=10, ) @@ -341,7 +341,7 @@ class RoomStateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.message_handler # Get all the current state for this room events = yield handler.get_state_events( @@ -357,7 +357,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request(request) content = yield self.handlers.message_handler.room_initial_sync( room_id=room_id, @@ -402,7 +402,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -427,7 +427,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): "sender": user.to_string(), "state_key": state_key, }, - client=client, + token_id=token_id, txn_id=txn_id, ) @@ -457,7 +457,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_id, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -469,7 +469,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): "sender": user.to_string(), "redacts": event_id, }, - client=client, + token_id=token_id, txn_id=txn_id, ) @@ -497,7 +497,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) room_id = urllib.unquote(room_id) target_user = UserID.from_string(urllib.unquote(user_id)) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 11d08fbced..4ae2d81b70 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 522a312c9e..b5edffdb60 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet): if LoginType.PASSWORD in result: # if using password, they should also be logged in - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) if auth_user.to_string() != result[LoginType.PASSWORD]: raise LoginError(400, "", Codes.UNKNOWN) user_id = auth_user.to_string() @@ -119,7 +119,7 @@ class ThreepidRestServlet(RestServlet): raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) threePidCreds = body['threePidCreds'] - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 703250cea8..f8f91b63f5 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, filter_id): target_user = UserID.from_string(user_id) - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) if target_user != auth_user: raise AuthError(403, "Cannot get filters for other users") @@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) if target_user != auth_user: raise AuthError(403, "Cannot create filters for other users") diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 718928eedd..ec1145454f 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -63,7 +63,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, device_id): - auth_user, client_info = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() # TODO: Check that the device_id matches that in the authentication # or derive the device_id from the authentication instead. @@ -108,7 +108,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, device_id): - auth_user, client_info = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() result = yield self.store.count_e2e_one_time_keys(user_id, device_id) @@ -180,7 +180,7 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): - auth_user, client_info = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) auth_user_id = auth_user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 40406e2ede..52e99f54d5 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -39,7 +39,7 @@ class ReceiptRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) yield self.receipts_handler.received_client_receipt( room_id, diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index f2fd0b9f32..83a257b969 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -87,7 +87,7 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) timeout = parse_integer(request, "timeout", default=0) limit = parse_integer(request, "limit", required=True) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index e77a20fb2e..c28dc86cd7 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource): @defer.inlineCallbacks def map_request_to_name(self, request): # auth the user - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) # namespace all file uploads on the user prefix = base64.urlsafe_b64encode( diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index cdd1d44e07..439d5a30a8 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource): @request_handler @defer.inlineCallbacks def _async_render_POST(self, request): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point content_length = request.getHeader("Content-Length") diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 53673b3bf5..77cb1dbd81 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -94,9 +94,9 @@ class DataStore(RoomMemberStore, RoomStore, ) @defer.inlineCallbacks - def insert_client_ip(self, user, access_token, device_id, ip, user_agent): + def insert_client_ip(self, user, access_token, ip, user_agent): now = int(self._clock.time_msec()) - key = (user.to_string(), access_token, device_id, ip) + key = (user.to_string(), access_token, ip) try: last_seen = self.client_ip_last_seen.get(key) @@ -120,7 +120,6 @@ class DataStore(RoomMemberStore, RoomStore, "user_agent": user_agent, }, values={ - "device_id": device_id, "last_seen": now, }, desc="insert_client_ip", @@ -132,7 +131,7 @@ class DataStore(RoomMemberStore, RoomStore, table="user_ips", keyvalues={"user_id": user.to_string()}, retcols=[ - "device_id", "access_token", "ip", "user_agent", "last_seen" + "access_token", "ip", "user_agent", "last_seen" ], desc="get_user_ip_and_agents", ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index f632306688..240d14c4d0 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -163,7 +163,7 @@ class RegistrationStore(SQLBaseStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id), device_id and whether they are + dict: Including the name (user_id) and whether they are an admin. Raises: StoreError if no user was found. @@ -228,8 +228,7 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.admin," - " access_tokens.device_id, access_tokens.id as token_id" + "SELECT users.name, users.admin, access_tokens.id as token_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" diff --git a/synapse/types.py b/synapse/types.py index e190374cbd..9cffc33d27 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -209,7 +209,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): return "t%d-%d" % (self.topological, self.stream) else: return "s%d" % (self.stream,) - - -# token_id is the primary key ID of the access token, not the access token itself. -ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 3343c635cc..777eb0395e 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -40,7 +40,6 @@ class AuthTestCase(unittest.TestCase): self.store.get_app_service_by_token = Mock(return_value=None) user_info = { "name": self.test_user, - "device_id": "nothing", "token_id": "ditto", "admin": False } @@ -49,7 +48,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, info) = yield self.auth.get_user_by_req(request) + (user, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): @@ -66,7 +65,6 @@ class AuthTestCase(unittest.TestCase): self.store.get_app_service_by_token = Mock(return_value=None) user_info = { "name": self.test_user, - "device_id": "nothing", "token_id": "ditto", "admin": False } @@ -86,7 +84,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, info) = yield self.auth.get_user_by_req(request) + (user, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), self.test_user) def test_get_user_by_req_appservice_bad_token(self): @@ -121,7 +119,7 @@ class AuthTestCase(unittest.TestCase): request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, info) = yield self.auth.get_user_by_req(request) + (user, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), masquerading_user_id) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 0b78a82a66..4039a86d85 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -74,7 +74,6 @@ class PresenceStateTestCase(unittest.TestCase): return { "user": UserID.from_string(myid), "admin": False, - "device_id": None, "token_id": 1, } @@ -163,7 +162,6 @@ class PresenceListTestCase(unittest.TestCase): return { "user": UserID.from_string(myid), "admin": False, - "device_id": None, "token_id": 1, } diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2e55cc08a1..dd1e67e0f9 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -58,7 +58,6 @@ class RoomPermissionsTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -445,7 +444,6 @@ class RoomsMemberListTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -525,7 +523,6 @@ class RoomsCreateTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -618,7 +615,6 @@ class RoomTopicTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } @@ -725,7 +721,6 @@ class RoomMemberStateTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -852,7 +847,6 @@ class RoomMessagesTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -949,7 +943,6 @@ class RoomInitialSyncTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index dc8bbaaf0e..0f70ce81dc 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -65,7 +65,6 @@ class RoomTypingTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index 15568b36cd..badb59f080 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -47,7 +47,6 @@ class V2AlphaRestTestCase(unittest.TestCase): return { "user": UserID.from_string(self.USER_ID), "admin": False, - "device_id": None, "token_id": 1, } hs.get_auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index a4f929796a..54fe10d58f 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -54,7 +54,6 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { "admin": 0, - "device_id": None, "name": self.user_id, }, result @@ -72,7 +71,6 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { "admin": 0, - "device_id": None, "name": self.user_id, }, result diff --git a/tests/utils.py b/tests/utils.py index d0fba2252d..ff560ef342 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -282,7 +282,6 @@ class MemoryDataStore(object): return { "name": self.tokens_to_users[token], "admin": 0, - "device_id": None, } except: raise StoreError(400, "User does not exist.") @@ -380,7 +379,7 @@ class MemoryDataStore(object): def get_ops_levels(self, room_id): return defer.succeed((5, 5, 5)) - def insert_client_ip(self, user, device_id, access_token, ip, user_agent): + def insert_client_ip(self, user, access_token, ip, user_agent): return defer.succeed(None) -- cgit 1.5.1 From a9d8bd95e722e24c7ddd6b14a3714c1b2f737d4d Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 25 Aug 2015 16:29:39 +0100 Subject: Stop looking up "admin", which we never read --- synapse/api/auth.py | 4 +--- synapse/storage/registration.py | 5 ++--- tests/api/test_auth.py | 2 -- tests/rest/client/v1/test_presence.py | 2 -- tests/rest/client/v1/test_rooms.py | 7 ------- tests/rest/client/v1/test_typing.py | 1 - tests/rest/client/v2_alpha/__init__.py | 1 - tests/storage/test_registration.py | 6 ++---- tests/utils.py | 1 - 9 files changed, 5 insertions(+), 24 deletions(-) (limited to 'tests/utils.py') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b41e34e658..65ee1452ce 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -392,8 +392,7 @@ class Auth(object): Args: token (str): The access token to get the user by. Returns: - dict : dict that includes the user and whether the - user is a server admin. + dict : dict that includes the user and the ID of their access token. Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -404,7 +403,6 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) user_info = { - "admin": bool(ret.get("admin", False)), "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), } diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 240d14c4d0..a2d0f7c4b1 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -163,8 +163,7 @@ class RegistrationStore(SQLBaseStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id) and whether they are - an admin. + dict: Including the name (user_id) and the ID of their access token. Raises: StoreError if no user was found. """ @@ -228,7 +227,7 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.admin, access_tokens.id as token_id" + "SELECT users.name, access_tokens.id as token_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 777eb0395e..22fc804331 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -41,7 +41,6 @@ class AuthTestCase(unittest.TestCase): user_info = { "name": self.test_user, "token_id": "ditto", - "admin": False } self.store.get_user_by_access_token = Mock(return_value=user_info) @@ -66,7 +65,6 @@ class AuthTestCase(unittest.TestCase): user_info = { "name": self.test_user, "token_id": "ditto", - "admin": False } self.store.get_user_by_access_token = Mock(return_value=user_info) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 4039a86d85..91547bdd06 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -73,7 +73,6 @@ class PresenceStateTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(myid), - "admin": False, "token_id": 1, } @@ -161,7 +160,6 @@ class PresenceListTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(myid), - "admin": False, "token_id": 1, } diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index dd1e67e0f9..34ab47d02e 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -57,7 +57,6 @@ class RoomPermissionsTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -443,7 +442,6 @@ class RoomsMemberListTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -522,7 +520,6 @@ class RoomsCreateTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -614,7 +611,6 @@ class RoomTopicTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } @@ -720,7 +716,6 @@ class RoomMemberStateTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -846,7 +841,6 @@ class RoomMessagesTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -942,7 +936,6 @@ class RoomInitialSyncTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 0f70ce81dc..1c4519406d 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -64,7 +64,6 @@ class RoomTypingTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index badb59f080..ef972a53aa 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -46,7 +46,6 @@ class V2AlphaRestTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.USER_ID), - "admin": False, "token_id": 1, } hs.get_auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 54fe10d58f..0cce6c37df 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -53,8 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { - "admin": 0, - "name": self.user_id, + "name": self.user_id, }, result ) @@ -70,8 +69,7 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { - "admin": 0, - "name": self.user_id, + "name": self.user_id, }, result ) diff --git a/tests/utils.py b/tests/utils.py index ff560ef342..3766a994f2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -281,7 +281,6 @@ class MemoryDataStore(object): try: return { "name": self.tokens_to_users[token], - "admin": 0, } except: raise StoreError(400, "User does not exist.") -- cgit 1.5.1 From 3063383547529a542b48f416d64fd98eaf6a2f60 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 26 Aug 2015 15:59:32 +0100 Subject: Swap out bcrypt for md5 in tests This reduces our ~8 second sequential test time down to ~7 seconds --- synapse/handlers/auth.py | 27 +++++++++++++++++++++++++-- synapse/handlers/register.py | 2 +- tests/utils.py | 13 +++++++++++++ 3 files changed, 39 insertions(+), 3 deletions(-) (limited to 'tests/utils.py') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1ab19cd1a6..59f687e0f1 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -324,7 +324,7 @@ class AuthHandler(BaseHandler): def _check_password(self, user_id, password, stored_hash): """Checks that user_id has passed password, raises LoginError if not.""" - if not bcrypt.checkpw(password, stored_hash): + if not self.validate_hash(password, stored_hash): logger.warn("Failed password login for user %s", user_id) raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -369,7 +369,7 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def set_password(self, user_id, newpassword): - password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) + password_hash = self.hash(newpassword) yield self.store.user_set_password_hash(user_id, password_hash) yield self.store.user_delete_access_tokens(user_id) @@ -391,3 +391,26 @@ class AuthHandler(BaseHandler): def _remove_session(self, session): logger.debug("Removing session %s", session) del self.sessions[session["id"]] + + def hash(self, password): + """Computes a secure hash of password. + + Args: + password (str): Password to hash. + + Returns: + Hashed password (str). + """ + return bcrypt.hashpw(password, bcrypt.gensalt()) + + def validate_hash(self, password, stored_hash): + """Validates that self.hash(password) == stored_hash. + + Args: + password (str): Password to hash. + stored_hash (str): Expected hash value. + + Returns: + Whether self.hash(password) == stored_hash (bool). + """ + return bcrypt.checkpw(password, stored_hash) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 56d125f753..855bb58522 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -82,7 +82,7 @@ class RegistrationHandler(BaseHandler): yield run_on_reactor() password_hash = None if password: - password_hash = bcrypt.hashpw(password, bcrypt.gensalt()) + password_hash = self.auth_handler().hash(password) if localpart: yield self.check_username(localpart) diff --git a/tests/utils.py b/tests/utils.py index 3766a994f2..dd19a16fc7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -27,6 +27,7 @@ from twisted.enterprise.adbapi import ConnectionPool from collections import namedtuple from mock import patch, Mock +import hashlib import urllib import urlparse @@ -67,6 +68,18 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): **kargs ) + # bcrypt is far too slow to be doing in unit tests + def swap_out_hash_for_testing(old_build_handlers): + def build_handlers(): + handlers = old_build_handlers() + auth_handler = handlers.auth_handler + auth_handler.hash = lambda p: hashlib.md5(p).hexdigest() + auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h + return handlers + return build_handlers + + hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers) + defer.returnValue(hs) -- cgit 1.5.1 From ec398af41c4d276abb02279efbcbb0aa08a4cbc8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Oct 2015 11:41:04 +0100 Subject: Expose error more nicely --- synapse/app/homeserver.py | 5 +- synapse/storage/__init__.py | 3 - synapse/storage/_schema_prepare.py | 395 ------------------------------------ synapse/storage/engines/postgres.py | 2 +- synapse/storage/engines/sqlite3.py | 2 +- synapse/storage/schema_prepare.py | 395 ++++++++++++++++++++++++++++++++++++ tests/utils.py | 2 +- 7 files changed, 400 insertions(+), 404 deletions(-) delete mode 100644 synapse/storage/_schema_prepare.py create mode 100644 synapse/storage/schema_prepare.py (limited to 'tests/utils.py') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 190b03e2f7..b284d07cf0 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -35,9 +35,8 @@ if __name__ == '__main__': from synapse.storage.engines import create_engine, IncorrectDatabaseSetup -from synapse.storage import ( - are_all_users_on_domain, UpgradeDatabaseException, -) +from synapse.storage import are_all_users_on_domain +from synapse.storage.schema_prepare import UpgradeDatabaseException from synapse.server import HomeServer diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4be629bff8..48a0633746 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -41,9 +41,6 @@ from .end_to_end_keys import EndToEndKeyStore from .receipts import ReceiptsStore -from ._schema_prepare import UpgradeDatabaseException - -__all__ = [UpgradeDatabaseException] import logging diff --git a/synapse/storage/_schema_prepare.py b/synapse/storage/_schema_prepare.py deleted file mode 100644 index 1ddf55be4d..0000000000 --- a/synapse/storage/_schema_prepare.py +++ /dev/null @@ -1,395 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import fnmatch -import imp -import logging -import os -import re - - -logger = logging.getLogger(__name__) - - -# Remember to update this number every time a change is made to database -# schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 24 - -dir_path = os.path.abspath(os.path.dirname(__file__)) - - -def read_schema(path): - """ Read the named database schema. - - Args: - path: Path of the database schema. - Returns: - A string containing the database schema. - """ - with open(path) as schema_file: - return schema_file.read() - - -class PrepareDatabaseException(Exception): - pass - - -class UpgradeDatabaseException(PrepareDatabaseException): - pass - - -def prepare_database(db_conn, database_engine): - """Prepares a database for usage. Will either create all necessary tables - or upgrade from an older schema version. - """ - try: - cur = db_conn.cursor() - version_info = _get_or_create_schema_state(cur, database_engine) - - if version_info: - user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine - ) - else: - _setup_new_database(cur, database_engine) - - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) - - cur.close() - db_conn.commit() - except: - db_conn.rollback() - raise - - -def _setup_new_database(cur, database_engine): - """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas. - - The "full_schemas" directory has subdirectories named after versions. This - function searches for the highest version less than or equal to - `SCHEMA_VERSION` and executes all .sql files in that directory. - - The function will then apply all deltas for all versions after the base - version. - - Example directory structure: - - schema/ - delta/ - ... - full_schemas/ - 3/ - test.sql - ... - 11/ - foo.sql - bar.sql - ... - - In the example foo.sql and bar.sql would be run, and then any delta files - for versions strictly greater than 11. - """ - current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) - - valid_dirs = [] - pattern = re.compile(r"^\d+(\.sql)?$") - for filename in directory_entries: - match = pattern.match(filename) - abs_path = os.path.join(current_dir, filename) - if match and os.path.isdir(abs_path): - ver = int(match.group(0)) - if ver <= SCHEMA_VERSION: - valid_dirs.append((ver, abs_path)) - else: - logger.warn("Unexpected entry in 'full_schemas': %s", filename) - - if not valid_dirs: - raise PrepareDatabaseException( - "Could not find a suitable base set of full schemas" - ) - - max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) - - logger.debug("Initialising schema v%d", max_current_ver) - - directory_entries = os.listdir(sql_dir) - - for filename in fnmatch.filter(directory_entries, "*.sql"): - sql_loc = os.path.join(sql_dir, filename) - logger.debug("Applying schema %s", sql_loc) - executescript(cur, sql_loc) - - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)" - ), - (max_current_ver, False,) - ) - - _upgrade_existing_database( - cur, - current_version=max_current_ver, - applied_delta_files=[], - upgraded=False, - database_engine=database_engine, - ) - - -def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): - """Upgrades an existing database. - - Delta files can either be SQL stored in *.sql files, or python modules - in *.py. - - There can be multiple delta files per version. Synapse will keep track of - which delta files have been applied, and will apply any that haven't been - even if there has been no version bump. This is useful for development - where orthogonal schema changes may happen on separate branches. - - Different delta files for the same version *must* be orthogonal and give - the same result when applied in any order. No guarantees are made on the - order of execution of these scripts. - - This is a no-op of current_version == SCHEMA_VERSION. - - Example directory structure: - - schema/ - delta/ - 11/ - foo.sql - ... - 12/ - foo.sql - bar.py - ... - full_schemas/ - ... - - In the example, if current_version is 11, then foo.sql will be run if and - only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in - some arbitrary order. - - Args: - cur (Cursor) - current_version (int): The current version of the schema. - applied_delta_files (list): A list of deltas that have already been - applied. - upgraded (bool): Whether the current version was generated by having - applied deltas or from full schema file. If `True` the function - will never apply delta files for the given `current_version`, since - the current_version wasn't generated by applying those delta files. - """ - - if current_version > SCHEMA_VERSION: - raise ValueError( - "Cannot use this database as it is too " + - "new for the server to understand" - ) - - start_ver = current_version - if not upgraded: - start_ver += 1 - - logger.debug("applied_delta_files: %s", applied_delta_files) - - for v in range(start_ver, SCHEMA_VERSION + 1): - logger.debug("Upgrading schema to v%d", v) - - delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) - - try: - directory_entries = os.listdir(delta_dir) - except OSError: - logger.exception("Could not open delta dir for version %d", v) - raise UpgradeDatabaseException( - "Could not open delta dir for version %d" % (v,) - ) - - directory_entries.sort() - for file_name in directory_entries: - relative_path = os.path.join(str(v), file_name) - logger.debug("Found file: %s", relative_path) - if relative_path in applied_delta_files: - continue - - absolute_path = os.path.join( - dir_path, "schema", "delta", relative_path, - ) - root_name, ext = os.path.splitext(file_name) - if ext == ".py": - # This is a python upgrade module. We need to import into some - # package and then execute its `run_upgrade` function. - module_name = "synapse.storage.v%d_%s" % ( - v, root_name - ) - with open(absolute_path) as python_file: - module = imp.load_source( - module_name, absolute_path, python_file - ) - logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) - elif ext == ".pyc": - # Sometimes .pyc files turn up anyway even though we've - # disabled their generation; e.g. from distribution package - # installers. Silently skip it - pass - elif ext == ".sql": - # A plain old .sql file, just read and execute it - logger.debug("Applying schema %s", relative_path) - executescript(cur, absolute_path) - else: - # Not a valid delta file. - logger.warn( - "Found directory entry that did not end in .py or" - " .sql: %s", - relative_path, - ) - continue - - # Mark as done. - cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", - ), - (v, relative_path) - ) - - cur.execute("DELETE FROM schema_version") - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)", - ), - (v, True) - ) - - -def get_statements(f): - statement_buffer = "" - in_comment = False # If we're in a /* ... */ style comment - - for line in f: - line = line.strip() - - if in_comment: - # Check if this line contains an end to the comment - comments = line.split("*/", 1) - if len(comments) == 1: - continue - line = comments[1] - in_comment = False - - # Remove inline block comments - line = re.sub(r"/\*.*\*/", " ", line) - - # Does this line start a comment? - comments = line.split("/*", 1) - if len(comments) > 1: - line = comments[0] - in_comment = True - - # Deal with line comments - line = line.split("--", 1)[0] - line = line.split("//", 1)[0] - - # Find *all* semicolons. We need to treat first and last entry - # specially. - statements = line.split(";") - - # We must prepend statement_buffer to the first statement - first_statement = "%s %s" % ( - statement_buffer.strip(), - statements[0].strip() - ) - statements[0] = first_statement - - # Every entry, except the last, is a full statement - for statement in statements[:-1]: - yield statement.strip() - - # The last entry did *not* end in a semicolon, so we store it for the - # next semicolon we find - statement_buffer = statements[-1].strip() - - -def executescript(txn, schema_path): - with open(schema_path, 'r') as f: - for statement in get_statements(f): - txn.execute(statement) - - -def _get_or_create_schema_state(txn, database_engine): - # Bluntly try creating the schema_version tables. - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - executescript(txn, schema_path) - - txn.execute("SELECT version, upgraded FROM schema_version") - row = txn.fetchone() - current_version = int(row[0]) if row else None - upgraded = bool(row[1]) if row else None - - if current_version: - txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), - (current_version,) - ) - applied_deltas = [d for d, in txn.fetchall()] - return current_version, applied_deltas, upgraded - - return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 949396044e..7e45dabf4c 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage._schema_prepare import prepare_database +from synapse.storage.schema_prepare import prepare_database from ._base import IncorrectDatabaseSetup diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index a66815ef2d..0eeaa45d19 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage._schema_prepare import ( +from synapse.storage.schema_prepare import ( prepare_database, prepare_sqlite3_database ) diff --git a/synapse/storage/schema_prepare.py b/synapse/storage/schema_prepare.py new file mode 100644 index 0000000000..1ddf55be4d --- /dev/null +++ b/synapse/storage/schema_prepare.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fnmatch +import imp +import logging +import os +import re + + +logger = logging.getLogger(__name__) + + +# Remember to update this number every time a change is made to database +# schema files, so the users will be informed on server restarts. +SCHEMA_VERSION = 24 + +dir_path = os.path.abspath(os.path.dirname(__file__)) + + +def read_schema(path): + """ Read the named database schema. + + Args: + path: Path of the database schema. + Returns: + A string containing the database schema. + """ + with open(path) as schema_file: + return schema_file.read() + + +class PrepareDatabaseException(Exception): + pass + + +class UpgradeDatabaseException(PrepareDatabaseException): + pass + + +def prepare_database(db_conn, database_engine): + """Prepares a database for usage. Will either create all necessary tables + or upgrade from an older schema version. + """ + try: + cur = db_conn.cursor() + version_info = _get_or_create_schema_state(cur, database_engine) + + if version_info: + user_version, delta_files, upgraded = version_info + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine + ) + else: + _setup_new_database(cur, database_engine) + + # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + + cur.close() + db_conn.commit() + except: + db_conn.rollback() + raise + + +def _setup_new_database(cur, database_engine): + """Sets up the database by finding a base set of "full schemas" and then + applying any necessary deltas. + + The "full_schemas" directory has subdirectories named after versions. This + function searches for the highest version less than or equal to + `SCHEMA_VERSION` and executes all .sql files in that directory. + + The function will then apply all deltas for all versions after the base + version. + + Example directory structure: + + schema/ + delta/ + ... + full_schemas/ + 3/ + test.sql + ... + 11/ + foo.sql + bar.sql + ... + + In the example foo.sql and bar.sql would be run, and then any delta files + for versions strictly greater than 11. + """ + current_dir = os.path.join(dir_path, "schema", "full_schemas") + directory_entries = os.listdir(current_dir) + + valid_dirs = [] + pattern = re.compile(r"^\d+(\.sql)?$") + for filename in directory_entries: + match = pattern.match(filename) + abs_path = os.path.join(current_dir, filename) + if match and os.path.isdir(abs_path): + ver = int(match.group(0)) + if ver <= SCHEMA_VERSION: + valid_dirs.append((ver, abs_path)) + else: + logger.warn("Unexpected entry in 'full_schemas': %s", filename) + + if not valid_dirs: + raise PrepareDatabaseException( + "Could not find a suitable base set of full schemas" + ) + + max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) + + logger.debug("Initialising schema v%d", max_current_ver) + + directory_entries = os.listdir(sql_dir) + + for filename in fnmatch.filter(directory_entries, "*.sql"): + sql_loc = os.path.join(sql_dir, filename) + logger.debug("Applying schema %s", sql_loc) + executescript(cur, sql_loc) + + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), + (max_current_ver, False,) + ) + + _upgrade_existing_database( + cur, + current_version=max_current_ver, + applied_delta_files=[], + upgraded=False, + database_engine=database_engine, + ) + + +def _upgrade_existing_database(cur, current_version, applied_delta_files, + upgraded, database_engine): + """Upgrades an existing database. + + Delta files can either be SQL stored in *.sql files, or python modules + in *.py. + + There can be multiple delta files per version. Synapse will keep track of + which delta files have been applied, and will apply any that haven't been + even if there has been no version bump. This is useful for development + where orthogonal schema changes may happen on separate branches. + + Different delta files for the same version *must* be orthogonal and give + the same result when applied in any order. No guarantees are made on the + order of execution of these scripts. + + This is a no-op of current_version == SCHEMA_VERSION. + + Example directory structure: + + schema/ + delta/ + 11/ + foo.sql + ... + 12/ + foo.sql + bar.py + ... + full_schemas/ + ... + + In the example, if current_version is 11, then foo.sql will be run if and + only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in + some arbitrary order. + + Args: + cur (Cursor) + current_version (int): The current version of the schema. + applied_delta_files (list): A list of deltas that have already been + applied. + upgraded (bool): Whether the current version was generated by having + applied deltas or from full schema file. If `True` the function + will never apply delta files for the given `current_version`, since + the current_version wasn't generated by applying those delta files. + """ + + if current_version > SCHEMA_VERSION: + raise ValueError( + "Cannot use this database as it is too " + + "new for the server to understand" + ) + + start_ver = current_version + if not upgraded: + start_ver += 1 + + logger.debug("applied_delta_files: %s", applied_delta_files) + + for v in range(start_ver, SCHEMA_VERSION + 1): + logger.debug("Upgrading schema to v%d", v) + + delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + + try: + directory_entries = os.listdir(delta_dir) + except OSError: + logger.exception("Could not open delta dir for version %d", v) + raise UpgradeDatabaseException( + "Could not open delta dir for version %d" % (v,) + ) + + directory_entries.sort() + for file_name in directory_entries: + relative_path = os.path.join(str(v), file_name) + logger.debug("Found file: %s", relative_path) + if relative_path in applied_delta_files: + continue + + absolute_path = os.path.join( + dir_path, "schema", "delta", relative_path, + ) + root_name, ext = os.path.splitext(file_name) + if ext == ".py": + # This is a python upgrade module. We need to import into some + # package and then execute its `run_upgrade` function. + module_name = "synapse.storage.v%d_%s" % ( + v, root_name + ) + with open(absolute_path) as python_file: + module = imp.load_source( + module_name, absolute_path, python_file + ) + logger.debug("Running script %s", relative_path) + module.run_upgrade(cur, database_engine) + elif ext == ".pyc": + # Sometimes .pyc files turn up anyway even though we've + # disabled their generation; e.g. from distribution package + # installers. Silently skip it + pass + elif ext == ".sql": + # A plain old .sql file, just read and execute it + logger.debug("Applying schema %s", relative_path) + executescript(cur, absolute_path) + else: + # Not a valid delta file. + logger.warn( + "Found directory entry that did not end in .py or" + " .sql: %s", + relative_path, + ) + continue + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)", + ), + (v, relative_path) + ) + + cur.execute("DELETE FROM schema_version") + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)", + ), + (v, True) + ) + + +def get_statements(f): + statement_buffer = "" + in_comment = False # If we're in a /* ... */ style comment + + for line in f: + line = line.strip() + + if in_comment: + # Check if this line contains an end to the comment + comments = line.split("*/", 1) + if len(comments) == 1: + continue + line = comments[1] + in_comment = False + + # Remove inline block comments + line = re.sub(r"/\*.*\*/", " ", line) + + # Does this line start a comment? + comments = line.split("/*", 1) + if len(comments) > 1: + line = comments[0] + in_comment = True + + # Deal with line comments + line = line.split("--", 1)[0] + line = line.split("//", 1)[0] + + # Find *all* semicolons. We need to treat first and last entry + # specially. + statements = line.split(";") + + # We must prepend statement_buffer to the first statement + first_statement = "%s %s" % ( + statement_buffer.strip(), + statements[0].strip() + ) + statements[0] = first_statement + + # Every entry, except the last, is a full statement + for statement in statements[:-1]: + yield statement.strip() + + # The last entry did *not* end in a semicolon, so we store it for the + # next semicolon we find + statement_buffer = statements[-1].strip() + + +def executescript(txn, schema_path): + with open(schema_path, 'r') as f: + for statement in get_statements(f): + txn.execute(statement) + + +def _get_or_create_schema_state(txn, database_engine): + # Bluntly try creating the schema_version tables. + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + executescript(txn, schema_path) + + txn.execute("SELECT version, upgraded FROM schema_version") + row = txn.fetchone() + current_version = int(row[0]) if row else None + upgraded = bool(row[1]) if row else None + + if current_version: + txn.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_schema_deltas WHERE version >= ?" + ), + (current_version,) + ) + applied_deltas = [d for d, in txn.fetchall()] + return current_version, applied_deltas, upgraded + + return None + + +def prepare_sqlite3_database(db_conn): + """This function should be called before `prepare_database` on sqlite3 + databases. + + Since we changed the way we store the current schema version and handle + updates to schemas, we need a way to upgrade from the old method to the + new. This only affects sqlite databases since they were the only ones + supported at the time. + """ + with db_conn: + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + create_schema = read_schema(schema_path) + db_conn.executescript(create_schema) + + c = db_conn.execute("SELECT * FROM schema_version") + rows = c.fetchall() + c.close() + + if not rows: + c = db_conn.execute("PRAGMA user_version") + row = c.fetchone() + c.close() + + if row and row[0]: + db_conn.execute( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + (row[0], False) + ) diff --git a/tests/utils.py b/tests/utils.py index dd19a16fc7..6eb575bd09 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,7 +16,7 @@ from synapse.http.server import HttpServer from synapse.api.errors import cs_error, CodeMessageException, StoreError from synapse.api.constants import EventTypes -from synapse.storage import prepare_database +from synapse.storage.schema_prepare import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer -- cgit 1.5.1 From 17c80c8a3d92acca5bda9b0fc7d9898547476563 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Oct 2015 13:56:22 +0100 Subject: rename schema_prepare to prepare_database --- synapse/app/homeserver.py | 2 +- synapse/storage/engines/postgres.py | 2 +- synapse/storage/engines/sqlite3.py | 2 +- synapse/storage/prepare_database.py | 395 ++++++++++++++++++++++++++++++++++++ synapse/storage/schema_prepare.py | 395 ------------------------------------ tests/utils.py | 2 +- 6 files changed, 399 insertions(+), 399 deletions(-) create mode 100644 synapse/storage/prepare_database.py delete mode 100644 synapse/storage/schema_prepare.py (limited to 'tests/utils.py') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b284d07cf0..af53acb369 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -36,7 +36,7 @@ if __name__ == '__main__': from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage import are_all_users_on_domain -from synapse.storage.schema_prepare import UpgradeDatabaseException +from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.server import HomeServer diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 7e45dabf4c..98d66e0a86 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.schema_prepare import prepare_database +from synapse.storage.prepare_database import prepare_database from ._base import IncorrectDatabaseSetup diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 0eeaa45d19..bad3b5c5ac 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.schema_prepare import ( +from synapse.storage.prepare_database import ( prepare_database, prepare_sqlite3_database ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py new file mode 100644 index 0000000000..1ddf55be4d --- /dev/null +++ b/synapse/storage/prepare_database.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fnmatch +import imp +import logging +import os +import re + + +logger = logging.getLogger(__name__) + + +# Remember to update this number every time a change is made to database +# schema files, so the users will be informed on server restarts. +SCHEMA_VERSION = 24 + +dir_path = os.path.abspath(os.path.dirname(__file__)) + + +def read_schema(path): + """ Read the named database schema. + + Args: + path: Path of the database schema. + Returns: + A string containing the database schema. + """ + with open(path) as schema_file: + return schema_file.read() + + +class PrepareDatabaseException(Exception): + pass + + +class UpgradeDatabaseException(PrepareDatabaseException): + pass + + +def prepare_database(db_conn, database_engine): + """Prepares a database for usage. Will either create all necessary tables + or upgrade from an older schema version. + """ + try: + cur = db_conn.cursor() + version_info = _get_or_create_schema_state(cur, database_engine) + + if version_info: + user_version, delta_files, upgraded = version_info + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine + ) + else: + _setup_new_database(cur, database_engine) + + # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + + cur.close() + db_conn.commit() + except: + db_conn.rollback() + raise + + +def _setup_new_database(cur, database_engine): + """Sets up the database by finding a base set of "full schemas" and then + applying any necessary deltas. + + The "full_schemas" directory has subdirectories named after versions. This + function searches for the highest version less than or equal to + `SCHEMA_VERSION` and executes all .sql files in that directory. + + The function will then apply all deltas for all versions after the base + version. + + Example directory structure: + + schema/ + delta/ + ... + full_schemas/ + 3/ + test.sql + ... + 11/ + foo.sql + bar.sql + ... + + In the example foo.sql and bar.sql would be run, and then any delta files + for versions strictly greater than 11. + """ + current_dir = os.path.join(dir_path, "schema", "full_schemas") + directory_entries = os.listdir(current_dir) + + valid_dirs = [] + pattern = re.compile(r"^\d+(\.sql)?$") + for filename in directory_entries: + match = pattern.match(filename) + abs_path = os.path.join(current_dir, filename) + if match and os.path.isdir(abs_path): + ver = int(match.group(0)) + if ver <= SCHEMA_VERSION: + valid_dirs.append((ver, abs_path)) + else: + logger.warn("Unexpected entry in 'full_schemas': %s", filename) + + if not valid_dirs: + raise PrepareDatabaseException( + "Could not find a suitable base set of full schemas" + ) + + max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) + + logger.debug("Initialising schema v%d", max_current_ver) + + directory_entries = os.listdir(sql_dir) + + for filename in fnmatch.filter(directory_entries, "*.sql"): + sql_loc = os.path.join(sql_dir, filename) + logger.debug("Applying schema %s", sql_loc) + executescript(cur, sql_loc) + + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), + (max_current_ver, False,) + ) + + _upgrade_existing_database( + cur, + current_version=max_current_ver, + applied_delta_files=[], + upgraded=False, + database_engine=database_engine, + ) + + +def _upgrade_existing_database(cur, current_version, applied_delta_files, + upgraded, database_engine): + """Upgrades an existing database. + + Delta files can either be SQL stored in *.sql files, or python modules + in *.py. + + There can be multiple delta files per version. Synapse will keep track of + which delta files have been applied, and will apply any that haven't been + even if there has been no version bump. This is useful for development + where orthogonal schema changes may happen on separate branches. + + Different delta files for the same version *must* be orthogonal and give + the same result when applied in any order. No guarantees are made on the + order of execution of these scripts. + + This is a no-op of current_version == SCHEMA_VERSION. + + Example directory structure: + + schema/ + delta/ + 11/ + foo.sql + ... + 12/ + foo.sql + bar.py + ... + full_schemas/ + ... + + In the example, if current_version is 11, then foo.sql will be run if and + only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in + some arbitrary order. + + Args: + cur (Cursor) + current_version (int): The current version of the schema. + applied_delta_files (list): A list of deltas that have already been + applied. + upgraded (bool): Whether the current version was generated by having + applied deltas or from full schema file. If `True` the function + will never apply delta files for the given `current_version`, since + the current_version wasn't generated by applying those delta files. + """ + + if current_version > SCHEMA_VERSION: + raise ValueError( + "Cannot use this database as it is too " + + "new for the server to understand" + ) + + start_ver = current_version + if not upgraded: + start_ver += 1 + + logger.debug("applied_delta_files: %s", applied_delta_files) + + for v in range(start_ver, SCHEMA_VERSION + 1): + logger.debug("Upgrading schema to v%d", v) + + delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + + try: + directory_entries = os.listdir(delta_dir) + except OSError: + logger.exception("Could not open delta dir for version %d", v) + raise UpgradeDatabaseException( + "Could not open delta dir for version %d" % (v,) + ) + + directory_entries.sort() + for file_name in directory_entries: + relative_path = os.path.join(str(v), file_name) + logger.debug("Found file: %s", relative_path) + if relative_path in applied_delta_files: + continue + + absolute_path = os.path.join( + dir_path, "schema", "delta", relative_path, + ) + root_name, ext = os.path.splitext(file_name) + if ext == ".py": + # This is a python upgrade module. We need to import into some + # package and then execute its `run_upgrade` function. + module_name = "synapse.storage.v%d_%s" % ( + v, root_name + ) + with open(absolute_path) as python_file: + module = imp.load_source( + module_name, absolute_path, python_file + ) + logger.debug("Running script %s", relative_path) + module.run_upgrade(cur, database_engine) + elif ext == ".pyc": + # Sometimes .pyc files turn up anyway even though we've + # disabled their generation; e.g. from distribution package + # installers. Silently skip it + pass + elif ext == ".sql": + # A plain old .sql file, just read and execute it + logger.debug("Applying schema %s", relative_path) + executescript(cur, absolute_path) + else: + # Not a valid delta file. + logger.warn( + "Found directory entry that did not end in .py or" + " .sql: %s", + relative_path, + ) + continue + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)", + ), + (v, relative_path) + ) + + cur.execute("DELETE FROM schema_version") + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)", + ), + (v, True) + ) + + +def get_statements(f): + statement_buffer = "" + in_comment = False # If we're in a /* ... */ style comment + + for line in f: + line = line.strip() + + if in_comment: + # Check if this line contains an end to the comment + comments = line.split("*/", 1) + if len(comments) == 1: + continue + line = comments[1] + in_comment = False + + # Remove inline block comments + line = re.sub(r"/\*.*\*/", " ", line) + + # Does this line start a comment? + comments = line.split("/*", 1) + if len(comments) > 1: + line = comments[0] + in_comment = True + + # Deal with line comments + line = line.split("--", 1)[0] + line = line.split("//", 1)[0] + + # Find *all* semicolons. We need to treat first and last entry + # specially. + statements = line.split(";") + + # We must prepend statement_buffer to the first statement + first_statement = "%s %s" % ( + statement_buffer.strip(), + statements[0].strip() + ) + statements[0] = first_statement + + # Every entry, except the last, is a full statement + for statement in statements[:-1]: + yield statement.strip() + + # The last entry did *not* end in a semicolon, so we store it for the + # next semicolon we find + statement_buffer = statements[-1].strip() + + +def executescript(txn, schema_path): + with open(schema_path, 'r') as f: + for statement in get_statements(f): + txn.execute(statement) + + +def _get_or_create_schema_state(txn, database_engine): + # Bluntly try creating the schema_version tables. + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + executescript(txn, schema_path) + + txn.execute("SELECT version, upgraded FROM schema_version") + row = txn.fetchone() + current_version = int(row[0]) if row else None + upgraded = bool(row[1]) if row else None + + if current_version: + txn.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_schema_deltas WHERE version >= ?" + ), + (current_version,) + ) + applied_deltas = [d for d, in txn.fetchall()] + return current_version, applied_deltas, upgraded + + return None + + +def prepare_sqlite3_database(db_conn): + """This function should be called before `prepare_database` on sqlite3 + databases. + + Since we changed the way we store the current schema version and handle + updates to schemas, we need a way to upgrade from the old method to the + new. This only affects sqlite databases since they were the only ones + supported at the time. + """ + with db_conn: + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + create_schema = read_schema(schema_path) + db_conn.executescript(create_schema) + + c = db_conn.execute("SELECT * FROM schema_version") + rows = c.fetchall() + c.close() + + if not rows: + c = db_conn.execute("PRAGMA user_version") + row = c.fetchone() + c.close() + + if row and row[0]: + db_conn.execute( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + (row[0], False) + ) diff --git a/synapse/storage/schema_prepare.py b/synapse/storage/schema_prepare.py deleted file mode 100644 index 1ddf55be4d..0000000000 --- a/synapse/storage/schema_prepare.py +++ /dev/null @@ -1,395 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import fnmatch -import imp -import logging -import os -import re - - -logger = logging.getLogger(__name__) - - -# Remember to update this number every time a change is made to database -# schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 24 - -dir_path = os.path.abspath(os.path.dirname(__file__)) - - -def read_schema(path): - """ Read the named database schema. - - Args: - path: Path of the database schema. - Returns: - A string containing the database schema. - """ - with open(path) as schema_file: - return schema_file.read() - - -class PrepareDatabaseException(Exception): - pass - - -class UpgradeDatabaseException(PrepareDatabaseException): - pass - - -def prepare_database(db_conn, database_engine): - """Prepares a database for usage. Will either create all necessary tables - or upgrade from an older schema version. - """ - try: - cur = db_conn.cursor() - version_info = _get_or_create_schema_state(cur, database_engine) - - if version_info: - user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine - ) - else: - _setup_new_database(cur, database_engine) - - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) - - cur.close() - db_conn.commit() - except: - db_conn.rollback() - raise - - -def _setup_new_database(cur, database_engine): - """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas. - - The "full_schemas" directory has subdirectories named after versions. This - function searches for the highest version less than or equal to - `SCHEMA_VERSION` and executes all .sql files in that directory. - - The function will then apply all deltas for all versions after the base - version. - - Example directory structure: - - schema/ - delta/ - ... - full_schemas/ - 3/ - test.sql - ... - 11/ - foo.sql - bar.sql - ... - - In the example foo.sql and bar.sql would be run, and then any delta files - for versions strictly greater than 11. - """ - current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) - - valid_dirs = [] - pattern = re.compile(r"^\d+(\.sql)?$") - for filename in directory_entries: - match = pattern.match(filename) - abs_path = os.path.join(current_dir, filename) - if match and os.path.isdir(abs_path): - ver = int(match.group(0)) - if ver <= SCHEMA_VERSION: - valid_dirs.append((ver, abs_path)) - else: - logger.warn("Unexpected entry in 'full_schemas': %s", filename) - - if not valid_dirs: - raise PrepareDatabaseException( - "Could not find a suitable base set of full schemas" - ) - - max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) - - logger.debug("Initialising schema v%d", max_current_ver) - - directory_entries = os.listdir(sql_dir) - - for filename in fnmatch.filter(directory_entries, "*.sql"): - sql_loc = os.path.join(sql_dir, filename) - logger.debug("Applying schema %s", sql_loc) - executescript(cur, sql_loc) - - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)" - ), - (max_current_ver, False,) - ) - - _upgrade_existing_database( - cur, - current_version=max_current_ver, - applied_delta_files=[], - upgraded=False, - database_engine=database_engine, - ) - - -def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): - """Upgrades an existing database. - - Delta files can either be SQL stored in *.sql files, or python modules - in *.py. - - There can be multiple delta files per version. Synapse will keep track of - which delta files have been applied, and will apply any that haven't been - even if there has been no version bump. This is useful for development - where orthogonal schema changes may happen on separate branches. - - Different delta files for the same version *must* be orthogonal and give - the same result when applied in any order. No guarantees are made on the - order of execution of these scripts. - - This is a no-op of current_version == SCHEMA_VERSION. - - Example directory structure: - - schema/ - delta/ - 11/ - foo.sql - ... - 12/ - foo.sql - bar.py - ... - full_schemas/ - ... - - In the example, if current_version is 11, then foo.sql will be run if and - only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in - some arbitrary order. - - Args: - cur (Cursor) - current_version (int): The current version of the schema. - applied_delta_files (list): A list of deltas that have already been - applied. - upgraded (bool): Whether the current version was generated by having - applied deltas or from full schema file. If `True` the function - will never apply delta files for the given `current_version`, since - the current_version wasn't generated by applying those delta files. - """ - - if current_version > SCHEMA_VERSION: - raise ValueError( - "Cannot use this database as it is too " + - "new for the server to understand" - ) - - start_ver = current_version - if not upgraded: - start_ver += 1 - - logger.debug("applied_delta_files: %s", applied_delta_files) - - for v in range(start_ver, SCHEMA_VERSION + 1): - logger.debug("Upgrading schema to v%d", v) - - delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) - - try: - directory_entries = os.listdir(delta_dir) - except OSError: - logger.exception("Could not open delta dir for version %d", v) - raise UpgradeDatabaseException( - "Could not open delta dir for version %d" % (v,) - ) - - directory_entries.sort() - for file_name in directory_entries: - relative_path = os.path.join(str(v), file_name) - logger.debug("Found file: %s", relative_path) - if relative_path in applied_delta_files: - continue - - absolute_path = os.path.join( - dir_path, "schema", "delta", relative_path, - ) - root_name, ext = os.path.splitext(file_name) - if ext == ".py": - # This is a python upgrade module. We need to import into some - # package and then execute its `run_upgrade` function. - module_name = "synapse.storage.v%d_%s" % ( - v, root_name - ) - with open(absolute_path) as python_file: - module = imp.load_source( - module_name, absolute_path, python_file - ) - logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) - elif ext == ".pyc": - # Sometimes .pyc files turn up anyway even though we've - # disabled their generation; e.g. from distribution package - # installers. Silently skip it - pass - elif ext == ".sql": - # A plain old .sql file, just read and execute it - logger.debug("Applying schema %s", relative_path) - executescript(cur, absolute_path) - else: - # Not a valid delta file. - logger.warn( - "Found directory entry that did not end in .py or" - " .sql: %s", - relative_path, - ) - continue - - # Mark as done. - cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", - ), - (v, relative_path) - ) - - cur.execute("DELETE FROM schema_version") - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)", - ), - (v, True) - ) - - -def get_statements(f): - statement_buffer = "" - in_comment = False # If we're in a /* ... */ style comment - - for line in f: - line = line.strip() - - if in_comment: - # Check if this line contains an end to the comment - comments = line.split("*/", 1) - if len(comments) == 1: - continue - line = comments[1] - in_comment = False - - # Remove inline block comments - line = re.sub(r"/\*.*\*/", " ", line) - - # Does this line start a comment? - comments = line.split("/*", 1) - if len(comments) > 1: - line = comments[0] - in_comment = True - - # Deal with line comments - line = line.split("--", 1)[0] - line = line.split("//", 1)[0] - - # Find *all* semicolons. We need to treat first and last entry - # specially. - statements = line.split(";") - - # We must prepend statement_buffer to the first statement - first_statement = "%s %s" % ( - statement_buffer.strip(), - statements[0].strip() - ) - statements[0] = first_statement - - # Every entry, except the last, is a full statement - for statement in statements[:-1]: - yield statement.strip() - - # The last entry did *not* end in a semicolon, so we store it for the - # next semicolon we find - statement_buffer = statements[-1].strip() - - -def executescript(txn, schema_path): - with open(schema_path, 'r') as f: - for statement in get_statements(f): - txn.execute(statement) - - -def _get_or_create_schema_state(txn, database_engine): - # Bluntly try creating the schema_version tables. - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - executescript(txn, schema_path) - - txn.execute("SELECT version, upgraded FROM schema_version") - row = txn.fetchone() - current_version = int(row[0]) if row else None - upgraded = bool(row[1]) if row else None - - if current_version: - txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), - (current_version,) - ) - applied_deltas = [d for d, in txn.fetchall()] - return current_version, applied_deltas, upgraded - - return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) diff --git a/tests/utils.py b/tests/utils.py index 6eb575bd09..4da51291a4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,7 +16,7 @@ from synapse.http.server import HttpServer from synapse.api.errors import cs_error, CodeMessageException, StoreError from synapse.api.constants import EventTypes -from synapse.storage.schema_prepare import prepare_database +from synapse.storage.prepare_database import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer -- cgit 1.5.1 From 771ca56c886dd08f707447cfff70acd3ba73e98c Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Mon, 2 Nov 2015 15:31:57 +0000 Subject: Remove more unused parameters --- synapse/handlers/room.py | 1 - synapse/handlers/sync.py | 1 - synapse/storage/stream.py | 3 +-- tests/storage/test_redaction.py | 4 ---- tests/storage/test_stream.py | 4 ---- tests/utils.py | 2 +- 6 files changed, 2 insertions(+), 13 deletions(-) (limited to 'tests/utils.py') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 36878a6c20..9184dcd048 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -827,7 +827,6 @@ class RoomEventSource(object): user_id=user.to_string(), from_key=from_key, to_key=to_key, - room_id=None, limit=limit, ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index eaa14f38df..4054efe555 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -342,7 +342,6 @@ class SyncHandler(BaseHandler): sync_config.user.to_string(), from_key=since_token.room_key, to_key=now_token.room_key, - room_id=None, limit=timeline_limit + 1, ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 15d4c2bf68..c728013f4c 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -158,8 +158,7 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) @log_function - def get_room_events_stream(self, user_id, from_key, to_key, room_id, - limit=0): + def get_room_events_stream(self, user_id, from_key, to_key, limit=0): current_room_membership_sql = ( "SELECT m.room_id FROM room_memberships as m " " INNER JOIN current_state_events as c" diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index b57006fcb4..dbf9700e6a 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -120,7 +120,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -149,7 +148,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -199,7 +197,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -228,7 +225,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index a658a789aa..e5c2c5cc8e 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -68,7 +68,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_bob.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -105,7 +104,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -147,7 +145,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_bob.to_string(), start, end, - None, # Is currently ignored ) # We should not get the message, as it happened *after* bob left. @@ -175,7 +172,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_bob.to_string(), start, end, - None, # Is currently ignored ) # We should not get the message, as it happened *after* bob left. diff --git a/tests/utils.py b/tests/utils.py index 4da51291a4..ca2c33cf8e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -335,7 +335,7 @@ class MemoryDataStore(object): ] def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, - room_id=None, limit=0, with_feedback=False): + limit=0, with_feedback=False): return ([], from_key) # TODO def get_joined_hosts_for_room(self, room_id): -- cgit 1.5.1 From 36c58b18a32f05a2f025bc916c14b9e2f78f439b Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 10 Nov 2015 15:51:40 +0000 Subject: Test for background updates --- tests/storage/test_background_update.py | 76 +++++++++++++++++++++++++++++++++ tests/utils.py | 3 ++ 2 files changed, 79 insertions(+) create mode 100644 tests/storage/test_background_update.py (limited to 'tests/utils.py') diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py new file mode 100644 index 0000000000..29289fa9b4 --- /dev/null +++ b/tests/storage/test_background_update.py @@ -0,0 +1,76 @@ +from tests import unittest +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.types import UserID, RoomID, RoomAlias + +from tests.utils import setup_test_homeserver + +from mock import Mock + +class BackgroundUpdateTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def setUp(self): + hs = yield setup_test_homeserver() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + self.update_handler = Mock() + + yield self.store.register_background_update_handler( + "test_update", self.update_handler + ) + + @defer.inlineCallbacks + def test_do_background_update(self): + desired_count = 1000; + duration_ms = 42; + + @defer.inlineCallbacks + def update(progress, count): + self.clock.advance_time_msec(count * duration_ms) + progress = {"my_key": progress["my_key"] + 1} + yield self.store.runInteraction( + "update_progress", + self.store._background_update_progress_txn, + "test_update", + progress, + ) + defer.returnValue(count) + + self.update_handler.side_effect = update + + yield self.store.start_background_update("test_update", {"my_key": 1}) + + self.update_handler.reset_mock() + result = yield self.store.do_background_update( + duration_ms * desired_count + ) + self.assertIsNotNone(result) + self.update_handler.assert_called_once_with( + {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE + ) + + @defer.inlineCallbacks + def update(progress, count): + yield self.store._end_background_update("test_update") + defer.returnValue(count) + + self.update_handler.side_effect = update + + self.update_handler.reset_mock() + result = yield self.store.do_background_update( + duration_ms * desired_count + ) + self.assertIsNotNone(result) + self.update_handler.assert_called_once_with( + {"my_key": 2}, desired_count + ) + + self.update_handler.reset_mock() + result = yield self.store.do_background_update( + duration_ms * desired_count + ) + self.assertIsNone(result) + self.assertFalse(self.update_handler.called) diff --git a/tests/utils.py b/tests/utils.py index ca2c33cf8e..91040c2efd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -243,6 +243,9 @@ class MockClock(object): else: self.timers.append(t) + def advance_time_msec(self, ms): + self.advance_time(ms / 1000.) + class SQLiteMemoryDbPool(ConnectionPool, object): def __init__(self): -- cgit 1.5.1