From cecbd636e94f4e900ef6d246b62698ff1c8ee352 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 20 Aug 2015 16:21:35 +0100 Subject: /tokenrefresh POST endpoint This allows refresh tokens to be exchanged for (access_token, refresh_token). It also starts issuing them on login, though no clients currently interpret them. --- synapse/storage/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index f154b1c8ae..53673b3bf5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 22 +SCHEMA_VERSION = 23 dir_path = os.path.abspath(os.path.dirname(__file__)) -- cgit 1.4.1 From a0b181bd17cb7ec2a43ed2dbdeb1bb40f3f4373c Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 25 Aug 2015 16:23:06 +0100 Subject: Remove completely unused concepts from codebase Removes device_id and ClientInfo. device_id is never actually written, and the matrix.org DB has no non-null entries for it; right now it is just cluttering up the code. This doesn't remove the columns from the database, because that's fiddly. --- synapse/api/auth.py | 17 ++++++--------- synapse/handlers/admin.py | 1 + synapse/handlers/message.py | 9 +++----- synapse/rest/client/v1/admin.py | 2 +- synapse/rest/client/v1/directory.py | 4 ++-- synapse/rest/client/v1/events.py | 4 ++-- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v1/presence.py | 8 +++---- synapse/rest/client/v1/profile.py | 4 ++-- synapse/rest/client/v1/pusher.py | 4 ++-- synapse/rest/client/v1/room.py | 34 ++++++++++++++--------------- synapse/rest/client/v1/voip.py | 2 +- synapse/rest/client/v2_alpha/account.py | 4 ++-- synapse/rest/client/v2_alpha/filter.py | 4 ++-- synapse/rest/client/v2_alpha/keys.py | 6 ++--- synapse/rest/client/v2_alpha/receipts.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/media/v0/content_repository.py | 2 +- synapse/rest/media/v1/upload_resource.py | 2 +- synapse/storage/__init__.py | 7 +++--- synapse/storage/registration.py | 5 ++--- synapse/types.py | 4 ---- tests/api/test_auth.py | 8 +++---- tests/rest/client/v1/test_presence.py | 2 -- tests/rest/client/v1/test_rooms.py | 7 ------ tests/rest/client/v1/test_typing.py | 1 - tests/rest/client/v2_alpha/__init__.py | 1 - tests/storage/test_registration.py | 2 -- tests/utils.py | 3 +-- 29 files changed, 63 insertions(+), 90 deletions(-) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3d9237ccc3..1496db7dff 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError from synapse.util.logutils import log_function -from synapse.types import UserID, ClientInfo +from synapse.types import UserID import logging @@ -322,9 +322,9 @@ class Auth(object): Args: request - An HTTP request with an access_token query parameter.
Returns: - tuple : of UserID and device string: - User ID object of the user making the request - ClientInfo object of the client instance the user is using + tuple of: + UserID (str) + Access token ID (str) Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -355,7 +355,7 @@ class Auth(object): request.authenticated_entity = user_id defer.returnValue( - (UserID.from_string(user_id), ClientInfo("", "")) + (UserID.from_string(user_id), "") ) return except KeyError: @@ -363,7 +363,6 @@ class Auth(object): user_info = yield self.get_user_by_access_token(access_token) user = user_info["user"] - device_id = user_info["device_id"] token_id = user_info["token_id"] ip_addr = self.hs.get_ip_from_request(request) @@ -375,14 +374,13 @@ class Auth(object): self.store.insert_client_ip( user=user, access_token=access_token, - device_id=user_info["device_id"], ip=ip_addr, user_agent=user_agent ) request.authenticated_entity = user.to_string() - defer.returnValue((user, ClientInfo(device_id, token_id))) + defer.returnValue((user, token_id,)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", @@ -396,7 +394,7 @@ class Auth(object): Args: token (str): The access token to get the user by. Returns: - dict : dict that includes the user, device_id, and whether the + dict : dict that includes the user and whether the user is a server admin. Raises: AuthError if no user by that token exists or the token is invalid. @@ -409,7 +407,6 @@ class Auth(object): ) user_info = { "admin": bool(ret.get("admin", False)), - "device_id": ret.get("device_id"), "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), } diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 1c9e7152c7..d852a18555 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -34,6 +34,7 @@ class AdminHandler(BaseHandler): d = {} for r in res: + # Note that device_id is always None device = d.setdefault(r["device_id"], {}) session = device.setdefault(r["access_token"], []) session.append({ diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index f12465fa2c..23b779ad7c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -183,7 +183,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def create_and_send_event(self, event_dict, ratelimit=True, - client=None, txn_id=None): + token_id=None, txn_id=None): """ Given a dict from a client, create and handle a new event. 
Creates an FrozenEvent object, filling out auth_events, prev_events, @@ -217,11 +217,8 @@ class MessageHandler(BaseHandler): builder.content ) - if client is not None: - if client.token_id is not None: - builder.internal_metadata.token_id = client.token_id - if client.device_id is not None: - builder.internal_metadata.device_id = client.device_id + if token_id is not None: + builder.internal_metadata.token_id = token_id if txn_id is not None: builder.internal_metadata.txn_id = txn_id diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 2ce754b028..504b63eab4 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(auth_user) if not is_admin and target_user != auth_user: diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 6758a888b3..4dcda57c1b 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet): try: # try to auth as a user - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) try: user_id = user.to_string() yield dir_handler.create_association( @@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet): # fallback to default user behaviour if they aren't an AS pass - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(user) if not is_admin: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 77b7c25a03..582148b659 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) try: handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) @@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, event_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.event_handler event = yield handler.get_event(auth_user, event_id) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 4a259bba64..4ea4da653c 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) with_feedback = "feedback" in request.args as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 78d4f2b128..a770efd841 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -32,7 +32,7 @@ 
class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = yield self.handlers.presence_handler.get_state( @@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = {} @@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): @@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 1e77eb49cf..fdde88a60d 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: @@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index c83287c028..3aabc93b8b 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -65,7 +65,7 @@ class PusherRestServlet(ClientV1RestServlet): try: yield pusher_pool.add_pusher( user_name=user.to_string(), - access_token=client.token_id, + access_token=token_id, profile_tag=content['profile_tag'], kind=content['kind'], app_id=content['app_id'], diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index b4a70cba99..c9c27dd5a0 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) room_config = self.get_room_config(request) info = yield self.make_room(room_config, auth_user, None) @@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield 
self.auth.get_user_by_req(request) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -143,7 +143,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -159,7 +159,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): msg_handler = self.handlers.message_handler yield msg_handler.create_and_send_event( - event_dict, client=client, txn_id=txn_id, + event_dict, token_id=token_id, txn_id=txn_id, ) defer.returnValue((200, {})) @@ -175,7 +175,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_type, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -186,7 +186,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): "room_id": room_id, "sender": user.to_string(), }, - client=client, + token_id=token_id, txn_id=txn_id, ) @@ -220,7 +220,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) # the identifier could be a room alias or a room id. Try one then the # other if it fails to parse, without swallowing other valid @@ -250,7 +250,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): "sender": user.to_string(), "state_key": user.to_string(), }, - client=client, + token_id=token_id, txn_id=txn_id, ) @@ -289,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.room_member_handler members = yield handler.get_room_members_as_pagination_chunk( room_id=room_id, @@ -317,7 +317,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request( request, default_limit=10, ) @@ -341,7 +341,7 @@ class RoomStateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.message_handler # Get all the current state for this room events = yield handler.get_state_events( @@ -357,7 +357,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request(request) content = yield self.handlers.message_handler.room_initial_sync( room_id=room_id, @@ -402,7 +402,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + 
user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -427,7 +427,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): "sender": user.to_string(), "state_key": state_key, }, - client=client, + token_id=token_id, txn_id=txn_id, ) @@ -457,7 +457,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_id, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -469,7 +469,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): "sender": user.to_string(), "redacts": event_id, }, - client=client, + token_id=token_id, txn_id=txn_id, ) @@ -497,7 +497,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) room_id = urllib.unquote(room_id) target_user = UserID.from_string(urllib.unquote(user_id)) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 11d08fbced..4ae2d81b70 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 522a312c9e..b5edffdb60 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet): if LoginType.PASSWORD in result: # if using password, they should also be logged in - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) if auth_user.to_string() != result[LoginType.PASSWORD]: raise LoginError(400, "", Codes.UNKNOWN) user_id = auth_user.to_string() @@ -119,7 +119,7 @@ class ThreepidRestServlet(RestServlet): raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) threePidCreds = body['threePidCreds'] - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 703250cea8..f8f91b63f5 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, filter_id): target_user = UserID.from_string(user_id) - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) if target_user != auth_user: raise AuthError(403, "Cannot get filters for other users") @@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) if target_user != auth_user: raise 
AuthError(403, "Cannot create filters for other users") diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 718928eedd..ec1145454f 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -63,7 +63,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, device_id): - auth_user, client_info = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() # TODO: Check that the device_id matches that in the authentication # or derive the device_id from the authentication instead. @@ -108,7 +108,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, device_id): - auth_user, client_info = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() result = yield self.store.count_e2e_one_time_keys(user_id, device_id) @@ -180,7 +180,7 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): - auth_user, client_info = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) auth_user_id = auth_user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 40406e2ede..52e99f54d5 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -39,7 +39,7 @@ class ReceiptRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) yield self.receipts_handler.received_client_receipt( room_id, diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index f2fd0b9f32..83a257b969 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -87,7 +87,7 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) timeout = parse_integer(request, "timeout", default=0) limit = parse_integer(request, "limit", required=True) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index e77a20fb2e..c28dc86cd7 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource): @defer.inlineCallbacks def map_request_to_name(self, request): # auth the user - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) # namespace all file uploads on the user prefix = base64.urlsafe_b64encode( diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index cdd1d44e07..439d5a30a8 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource): @request_handler @defer.inlineCallbacks def _async_render_POST(self, request): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) # TODO: The 
checks here are a bit late. The content will have # already been uploaded to a tmp file at this point content_length = request.getHeader("Content-Length") diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 53673b3bf5..77cb1dbd81 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -94,9 +94,9 @@ class DataStore(RoomMemberStore, RoomStore, ) @defer.inlineCallbacks - def insert_client_ip(self, user, access_token, device_id, ip, user_agent): + def insert_client_ip(self, user, access_token, ip, user_agent): now = int(self._clock.time_msec()) - key = (user.to_string(), access_token, device_id, ip) + key = (user.to_string(), access_token, ip) try: last_seen = self.client_ip_last_seen.get(key) @@ -120,7 +120,6 @@ class DataStore(RoomMemberStore, RoomStore, "user_agent": user_agent, }, values={ - "device_id": device_id, "last_seen": now, }, desc="insert_client_ip", @@ -132,7 +131,7 @@ class DataStore(RoomMemberStore, RoomStore, table="user_ips", keyvalues={"user_id": user.to_string()}, retcols=[ - "device_id", "access_token", "ip", "user_agent", "last_seen" + "access_token", "ip", "user_agent", "last_seen" ], desc="get_user_ip_and_agents", ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index f632306688..240d14c4d0 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -163,7 +163,7 @@ class RegistrationStore(SQLBaseStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id), device_id and whether they are + dict: Including the name (user_id) and whether they are an admin. Raises: StoreError if no user was found. @@ -228,8 +228,7 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.admin," - " access_tokens.device_id, access_tokens.id as token_id" + "SELECT users.name, users.admin, access_tokens.id as token_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" diff --git a/synapse/types.py b/synapse/types.py index e190374cbd..9cffc33d27 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -209,7 +209,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): return "t%d-%d" % (self.topological, self.stream) else: return "s%d" % (self.stream,) - - -# token_id is the primary key ID of the access token, not the access token itself. 
-ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 3343c635cc..777eb0395e 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -40,7 +40,6 @@ class AuthTestCase(unittest.TestCase): self.store.get_app_service_by_token = Mock(return_value=None) user_info = { "name": self.test_user, - "device_id": "nothing", "token_id": "ditto", "admin": False } @@ -49,7 +48,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, info) = yield self.auth.get_user_by_req(request) + (user, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): @@ -66,7 +65,6 @@ class AuthTestCase(unittest.TestCase): self.store.get_app_service_by_token = Mock(return_value=None) user_info = { "name": self.test_user, - "device_id": "nothing", "token_id": "ditto", "admin": False } @@ -86,7 +84,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, info) = yield self.auth.get_user_by_req(request) + (user, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), self.test_user) def test_get_user_by_req_appservice_bad_token(self): @@ -121,7 +119,7 @@ class AuthTestCase(unittest.TestCase): request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, info) = yield self.auth.get_user_by_req(request) + (user, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), masquerading_user_id) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 0b78a82a66..4039a86d85 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -74,7 +74,6 @@ class PresenceStateTestCase(unittest.TestCase): return { "user": UserID.from_string(myid), "admin": False, - "device_id": None, "token_id": 1, } @@ -163,7 +162,6 @@ class PresenceListTestCase(unittest.TestCase): return { "user": UserID.from_string(myid), "admin": False, - "device_id": None, "token_id": 1, } diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2e55cc08a1..dd1e67e0f9 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -58,7 +58,6 @@ class RoomPermissionsTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -445,7 +444,6 @@ class RoomsMemberListTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -525,7 +523,6 @@ class RoomsCreateTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -618,7 +615,6 @@ class RoomTopicTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": 
None, "token_id": 1, } @@ -725,7 +721,6 @@ class RoomMemberStateTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -852,7 +847,6 @@ class RoomMessagesTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -949,7 +943,6 @@ class RoomInitialSyncTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index dc8bbaaf0e..0f70ce81dc 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -65,7 +65,6 @@ class RoomTypingTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index 15568b36cd..badb59f080 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -47,7 +47,6 @@ class V2AlphaRestTestCase(unittest.TestCase): return { "user": UserID.from_string(self.USER_ID), "admin": False, - "device_id": None, "token_id": 1, } hs.get_auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index a4f929796a..54fe10d58f 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -54,7 +54,6 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { "admin": 0, - "device_id": None, "name": self.user_id, }, result @@ -72,7 +71,6 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { "admin": 0, - "device_id": None, "name": self.user_id, }, result diff --git a/tests/utils.py b/tests/utils.py index d0fba2252d..ff560ef342 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -282,7 +282,6 @@ class MemoryDataStore(object): return { "name": self.tokens_to_users[token], "admin": 0, - "device_id": None, } except: raise StoreError(400, "User does not exist.") @@ -380,7 +379,7 @@ class MemoryDataStore(object): def get_ops_levels(self, room_id): return defer.succeed((5, 5, 5)) - def insert_client_ip(self, user, device_id, access_token, ip, user_agent): + def insert_client_ip(self, user, access_token, ip, user_agent): return defer.succeed(None) -- cgit 1.4.1 From 7213588083dd9a721b0cd623fe22b308f25f19a5 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 22 Sep 2015 12:57:40 +0100 Subject: Implement configurable stats reporting SYN-287 This requires that HS owners either opt in or out of stats reporting. When --generate-config is passed, --report-stats must be specified. If an already-generated config is used and doesn't have the report_stats key, the server asks for the key to be set and refuses to start until it is.
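The mechanism behind this commit, visible across the config diffs that follow, is that Config.generate_config() now passes one shared keyword set through invoke_all() to every config class, and each default_config() picks out only the keywords it needs via **kwargs. A minimal standalone sketch of that pattern follows; it is not the actual Synapse code, and the class names and config keys are simplified stand-ins:

# Sketch only: simplified stand-ins for the Config classes changed in this patch.
class ServerConfig(object):
    def default_config(self, server_name, **kwargs):
        # Consumes server_name; silently ignores report_stats and any other keyword.
        return "server_name: %s\n" % (server_name,)

class MetricsConfig(object):
    def default_config(self, report_stats=None, **kwargs):
        # Emits a report_stats key only when the operator actually made a choice.
        suffix = "" if report_stats is None else "report_stats: %s\n" % (report_stats,)
        return "enable_metrics: False\n" + suffix

def invoke_all(configs, **kwargs):
    # Same fan-out as Config.invoke_all(): every class sees every keyword.
    return "".join(c.default_config(**kwargs) for c in configs)

print(invoke_all([ServerConfig(), MetricsConfig()],
                 config_dir_path=".", server_name="example.com",
                 report_stats=True))

Run as-is, the sketch prints a config fragment ending in "report_stats: True": the key whose absence makes the server refuse to start.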
--- synapse/app/homeserver.py | 35 ++++- synapse/app/synctl.py | 12 ++- synapse/config/_base.py | 45 +++++++- synapse/config/appservice.py | 2 +- synapse/config/captcha.py | 2 +- synapse/config/database.py | 2 +- synapse/config/key.py | 2 +- synapse/config/logger.py | 2 +- synapse/config/metrics.py | 8 +- synapse/config/ratelimiting.py | 2 +- synapse/config/registration.py | 2 +- synapse/config/repository.py | 2 +- synapse/config/saml2.py | 2 +- synapse/config/server.py | 2 +- synapse/config/tls.py | 2 +- synapse/config/voip.py | 2 +- synapse/storage/__init__.py | 20 +++- synapse/storage/events.py | 58 ++++++++++- synapse/storage/registration.py | 12 +++ .../storage/schema/delta/24/stats_reporting.sql | 22 ++++ tests/storage/event_injector.py | 81 ++++++++++++++ tests/storage/test_events.py | 116 +++++++++++++++++++++ tests/storage/test_room.py | 2 +- tests/storage/test_stream.py | 68 +++--------- 24 files changed, 425 insertions(+), 78 deletions(-) create mode 100644 synapse/storage/schema/delta/24/stats_reporting.sql create mode 100644 tests/storage/event_injector.py create mode 100644 tests/storage/test_events.py diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 15c0a4a003..b4429bd4f3 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -42,7 +42,7 @@ from synapse.storage import ( from synapse.server import HomeServer -from twisted.internet import reactor +from twisted.internet import reactor, task, defer from twisted.application import service from twisted.enterprise import adbapi from twisted.web.resource import Resource, EncodingResourceWrapper @@ -677,6 +677,39 @@ def run(hs): ThreadPool._worker = profile(ThreadPool._worker) reactor.run = profile(reactor.run) + start_time = hs.get_clock().time() + + @defer.inlineCallbacks + def phone_stats_home(): + now = int(hs.get_clock().time()) + uptime = int(now - start_time) + if uptime < 0: + uptime = 0 + + stats = {} + stats["homeserver"] = hs.config.server_name + stats["timestamp"] = now + stats["uptime_seconds"] = uptime + stats["total_users"] = yield hs.get_datastore().count_all_users() + + all_rooms = yield hs.get_datastore().get_rooms(False) + stats["total_room_count"] = len(all_rooms) + + stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() + daily_messages = yield hs.get_datastore().count_daily_messages() + if daily_messages is not None: + stats["daily_messages"] = daily_messages + + logger.info("Reporting stats to matrix.org: %s" % (stats,)) + hs.get_simple_http_client().put_json( + "https://matrix.org/report-usage-stats/push", + stats + ) + + if hs.config.report_stats: + phone_home_task = task.LoopingCall(phone_stats_home) + phone_home_task.start(60 * 60 * 24, now=False) + def in_thread(): with LoggingContext("run"): change_resource_limit(hs.config.soft_file_limit) diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 1f7d543c31..6bcc437591 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -25,6 +25,7 @@ SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"] CONFIGFILE = "homeserver.yaml" GREEN = "\x1b[1;32m" +RED = "\x1b[1;31m" NORMAL = "\x1b[m" if not os.path.exists(CONFIGFILE): @@ -45,8 +46,15 @@ def start(): print "Starting ...", args = SYNAPSE args.extend(["--daemonize", "-c", CONFIGFILE]) - subprocess.check_call(args) - print GREEN + "started" + NORMAL + try: + subprocess.check_call(args) + print GREEN + "started" + NORMAL + except subprocess.CalledProcessError as e: + print ( + RED + +
"error starting (exit code: %d); see above for logs" % e.returncode + + NORMAL + ) def stop(): diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 8a75c48733..b9983f72a2 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -26,6 +26,16 @@ class ConfigError(Exception): class Config(object): + stats_reporting_begging_spiel = ( + "We would really appreciate it if you could help our project out by " + "reporting anonymized usage statistics from your homeserver. Only very " + "basic aggregate data (e.g. number of users) will be reported, but it " + "helps us to track the growth of the Matrix community, and helps us to " + "make Matrix a success, as well as to convince other networks that they " + "should peer with us.\n" + "Thank you." + ) + @staticmethod def parse_size(value): if isinstance(value, int) or isinstance(value, long): @@ -111,11 +121,14 @@ class Config(object): results.append(getattr(cls, name)(self, *args, **kargs)) return results - def generate_config(self, config_dir_path, server_name): + def generate_config(self, config_dir_path, server_name, report_stats=None): default_config = "# vim:ft=yaml\n" default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( - "default_config", config_dir_path, server_name + "default_config", + config_dir_path=config_dir_path, + server_name=server_name, + report_stats=report_stats, )) config = yaml.load(default_config) @@ -139,6 +152,12 @@ class Config(object): action="store_true", help="Generate a config file for the server name" ) + config_parser.add_argument( + "--report-stats", + action="store", + help="Stuff", + choices=["yes", "no"] + ) config_parser.add_argument( "--generate-keys", action="store_true", @@ -189,6 +208,11 @@ class Config(object): config_files.append(config_path) if config_args.generate_config: + if config_args.report_stats is None: + config_parser.error( + "Please specify either --report-stats=yes or --report-stats=no\n\n" + + cls.stats_reporting_begging_spiel + ) if not config_files: config_parser.error( "Must supply a config file.\nA config file can be automatically" @@ -211,7 +235,9 @@ class Config(object): os.makedirs(config_dir_path) with open(config_path, "wb") as config_file: config_bytes, config = obj.generate_config( - config_dir_path, server_name + config_dir_path=config_dir_path, + server_name=server_name, + report_stats=(config_args.report_stats == "yes"), ) obj.invoke_all("generate_files", config) config_file.write(config_bytes) @@ -261,9 +287,20 @@ class Config(object): specified_config.update(yaml_config) server_name = specified_config["server_name"] - _, config = obj.generate_config(config_dir_path, server_name) + _, config = obj.generate_config( + config_dir_path=config_dir_path, + server_name=server_name + ) config.pop("log_config") config.update(specified_config) + if "report_stats" not in config: + sys.stderr.write( + "Please opt in or out of reporting anonymized homeserver usage " + "statistics, by setting the report_stats key in your config file " + " ( " + config_path + " ) " + + "to either True or False.\n\n" + + Config.stats_reporting_begging_spiel + "\n") + sys.exit(1) if generate_keys: obj.invoke_all("generate_files", config) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 38f41933b7..b8d301995e 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -20,7 +20,7 @@ class AppServiceConfig(Config): def read_config(self, config): self.app_service_config_files = config.get("app_service_config_files", []) 
- def default_config(cls, config_dir_path, server_name): + def default_config(cls, **kwargs): return """\ # A list of application service config file to use app_service_config_files: [] diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index 15a132b4e3..dd92fcd0dc 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -24,7 +24,7 @@ class CaptchaConfig(Config): self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"] - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): return """\ ## Captcha ## diff --git a/synapse/config/database.py b/synapse/config/database.py index f0611e8884..baeda8f300 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -45,7 +45,7 @@ class DatabaseConfig(Config): self.set_databasepath(config.get("database_path")) - def default_config(self, config, config_dir_path): + def default_config(self, **kwargs): database_path = self.abspath("homeserver.db") return """\ # Database configuration diff --git a/synapse/config/key.py b/synapse/config/key.py index 23ac8a3fca..2c187065e5 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -40,7 +40,7 @@ class KeyConfig(Config): config["perspectives"] ) - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) return """\ ## Signing Keys ## diff --git a/synapse/config/logger.py b/synapse/config/logger.py index daca698d0c..bd0c17c861 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -70,7 +70,7 @@ class LoggingConfig(Config): self.log_config = self.abspath(config.get("log_config")) self.log_file = self.abspath(config.get("log_file")) - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): log_file = self.abspath("homeserver.log") log_config = self.abspath( os.path.join(config_dir_path, server_name + ".log.config") diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index ae5a691527..825fec9a38 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -19,13 +19,15 @@ from ._base import Config class MetricsConfig(Config): def read_config(self, config): self.enable_metrics = config["enable_metrics"] + self.report_stats = config.get("report_stats", None) self.metrics_port = config.get("metrics_port") self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1") - def default_config(self, config_dir_path, server_name): - return """\ + def default_config(self, report_stats=None, **kwargs): + suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n" + return ("""\ ## Metrics ### # Enable collection and rendering of performance metrics enable_metrics: False - """ + """ + suffix) % locals() diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 76d9970e5b..611b598ec7 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -27,7 +27,7 @@ class RatelimitConfig(Config): self.federation_rc_reject_limit = config["federation_rc_reject_limit"] self.federation_rc_concurrent = config["federation_rc_concurrent"] - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): return """\ ## Ratelimiting ## diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 62de4b399f..fa98eced34 
100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -34,7 +34,7 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.macaroon_secret_key = config.get("macaroon_secret_key") - def default_config(self, config_dir, server_name): + def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) macaroon_secret_key = random_string_with_symbols(50) return """\ diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 64644b9a7a..2fcf872449 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -60,7 +60,7 @@ class ContentRepositoryConfig(Config): config["thumbnail_sizes"] ) - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): media_store = self.default_path("media_store") uploads_path = self.default_path("uploads") return """ diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 1532036876..4c6133cf22 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -41,7 +41,7 @@ class SAML2Config(Config): self.saml2_config_path = None self.saml2_idp_redirect_url = None - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): return """ # Enable SAML2 for registration and login. Uses pysaml2 # config_path: Path to the sp_conf.py configuration file diff --git a/synapse/config/server.py b/synapse/config/server.py index a03e55c223..4d12d49857 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -117,7 +117,7 @@ class ServerConfig(Config): self.content_addr = content_addr - def default_config(self, config_dir_path, server_name): + def default_config(self, server_name, **kwargs): if ":" in server_name: bind_port = int(server_name.split(":")[1]) unsecure_port = bind_port - 400 diff --git a/synapse/config/tls.py b/synapse/config/tls.py index e6023a718d..0ac2698293 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -50,7 +50,7 @@ class TlsConfig(Config): "use_insecure_ssl_client_just_for_testing_do_not_use" ) - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) tls_certificate_path = base_key_name + ".tls.crt" diff --git a/synapse/config/voip.py b/synapse/config/voip.py index a1707223d3..a093354ccd 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -22,7 +22,7 @@ class VoipConfig(Config): self.turn_shared_secret = config["turn_shared_secret"] self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): return """\ ## Turn ## diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 77cb1dbd81..b64c90d631 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. 
-SCHEMA_VERSION = 23 +SCHEMA_VERSION = 24 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -126,6 +126,24 @@ class DataStore(RoomMemberStore, RoomStore, lock=False, ) + @defer.inlineCallbacks + def count_daily_users(self): + def _count_users(txn): + txn.execute( + "SELECT COUNT(DISTINCT user_id) AS users" + " FROM user_ips" + " WHERE last_seen > ?", + # This is close enough to a day for our purposes. + (int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),) + ) + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_users", _count_users) + defer.returnValue(ret) + def get_user_ip_and_agents(self, user): return self._simple_select_list( table="user_ips", diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 0a477e3122..2b51db9940 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from _base import SQLBaseStore, _RollbackButIsFineException from twisted.internet import defer, reactor @@ -28,6 +27,7 @@ from canonicaljson import encode_canonical_json from contextlib import contextmanager import logging +import math import ujson as json logger = logging.getLogger(__name__) @@ -905,3 +905,59 @@ class EventsStore(SQLBaseStore): txn.execute(sql, (event.event_id,)) result = txn.fetchone() return result[0] if result else None + + @defer.inlineCallbacks + def count_daily_messages(self): + def _count_messages(txn): + now = self.hs.get_clock().time() + + txn.execute( + "SELECT reported_stream_token, reported_time FROM stats_reporting" + ) + last_reported = self.cursor_to_dict(txn) + + txn.execute( + "SELECT stream_ordering" + " FROM events" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + now_reporting = self.cursor_to_dict(txn) + if not now_reporting: + return None + now_reporting = now_reporting[0]["stream_ordering"] + + txn.execute("DELETE FROM stats_reporting") + txn.execute( + "INSERT INTO stats_reporting" + " (reported_stream_token, reported_time)" + " VALUES (?, ?)", + (now_reporting, now,) + ) + + if not last_reported: + return None + + # Close enough to correct for our purposes. + yesterday = (now - 24 * 60 * 60) + if math.fabs(yesterday - last_reported[0]["reported_time"]) > 60 * 60: + return None + + txn.execute( + "SELECT COUNT(*) as messages" + " FROM events NATURAL JOIN event_json" + " WHERE json like '%m.room.message%'" + " AND stream_ordering > ?" 
+ " AND stream_ordering <= ?", + ( + last_reported[0]["reported_stream_token"], + now_reporting, + ) + ) + rows = self.cursor_to_dict(txn) + if not rows: + return None + return rows[0]["messages"] + + ret = yield self.runInteraction("count_messages", _count_messages) + defer.returnValue(ret) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index c9ceb132ae..6d76237658 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -289,3 +289,15 @@ class RegistrationStore(SQLBaseStore): if ret: defer.returnValue(ret['user_id']) defer.returnValue(None) + + @defer.inlineCallbacks + def count_all_users(self): + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users") + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_users", _count_users) + defer.returnValue(ret) diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/schema/delta/24/stats_reporting.sql new file mode 100644 index 0000000000..e9165d2917 --- /dev/null +++ b/synapse/storage/schema/delta/24/stats_reporting.sql @@ -0,0 +1,22 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Should only ever contain one row +CREATE TABLE IF NOT EXISTS stats_reporting( + -- The stream ordering token which was most recently reported as stats + reported_stream_token INTEGER, + -- The time (seconds since epoch) stats were most recently reported + reported_time BIGINT +); diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py new file mode 100644 index 0000000000..42bd8928bd --- /dev/null +++ b/tests/storage/event_injector.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from tests import unittest +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.types import UserID, RoomID + +from tests.utils import setup_test_homeserver + +from mock import Mock + + +class EventInjector: + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.message_handler = hs.get_handlers().message_handler + self.event_builder_factory = hs.get_event_builder_factory() + + @defer.inlineCallbacks + def create_room(self, room): + builder = self.event_builder_factory.new({ + "type": EventTypes.Create, + "room_id": room.to_string(), + "content": {}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) + + @defer.inlineCallbacks + def inject_room_member(self, room, user, membership): + builder = self.event_builder_factory.new({ + "type": EventTypes.Member, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"membership": membership}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) + + defer.returnValue(event) + + @defer.inlineCallbacks + def inject_message(self, room, user, body): + builder = self.event_builder_factory.new({ + "type": EventTypes.Message, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"body": body, "msgtype": u"message"}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py new file mode 100644 index 0000000000..313013009e --- /dev/null +++ b/tests/storage/test_events.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import uuid +from mock.mock import Mock +from synapse.types import RoomID, UserID + +from tests import unittest +from twisted.internet import defer +from tests.storage.event_injector import EventInjector + +from tests.utils import setup_test_homeserver + + +class EventsStoreTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + resource_for_federation=Mock(), + http_client=None, + ) + self.store = self.hs.get_datastore() + self.db_pool = self.hs.get_db_pool() + self.message_handler = self.hs.get_handlers().message_handler + self.event_injector = EventInjector(self.hs) + + @defer.inlineCallbacks + def test_count_daily_messages(self): + self.db_pool.runQuery("DELETE FROM stats_reporting") + + self.hs.clock.now = 100 + + # Never reported before, and nothing which could be reported + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + count = yield self.db_pool.runQuery("SELECT COUNT(*) FROM stats_reporting") + self.assertEqual([(0,)], count) + + # Create something to report + room = RoomID.from_string("!abc123:test") + user = UserID.from_string("@raccoonlover:test") + yield self.event_injector.create_room(room) + + self.base_event = yield self._get_last_stream_token() + + yield self.event_injector.inject_message(room, user, "Raccoons are really cute") + + # Never reported before, something could be reported, but isn't because + # it isn't old enough. + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(1, self.hs.clock.now) + + # Already reported yesterday, two new events from today. + yield self.event_injector.inject_message(room, user, "Yeah they are!") + yield self.event_injector.inject_message(room, user, "Incredibly!") + self.hs.clock.now += 60 * 60 * 24 + count = yield self.store.count_daily_messages() + self.assertEqual(2, count) # 2 since yesterday + self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever + + # Last reported too recently. 
+ yield self.event_injector.inject_message(room, user, "Who could disagree?") + self.hs.clock.now += 60 * 60 * 22 + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(4, self.hs.clock.now) + + # Last reported too long ago + yield self.event_injector.inject_message(room, user, "No one.") + self.hs.clock.now += 60 * 60 * 26 + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(5, self.hs.clock.now) + + # And now let's actually report something + yield self.event_injector.inject_message(room, user, "Indeed.") + yield self.event_injector.inject_message(room, user, "Indeed.") + yield self.event_injector.inject_message(room, user, "Indeed.") + # A little over 24 hours is fine :) + self.hs.clock.now += (60 * 60 * 24) + 50 + count = yield self.store.count_daily_messages() + self.assertEqual(3, count) + self._assert_stats_reporting(8, self.hs.clock.now) + + @defer.inlineCallbacks + def _get_last_stream_token(self): + rows = yield self.db_pool.runQuery( + "SELECT stream_ordering" + " FROM events" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + if not rows: + defer.returnValue(0) + else: + defer.returnValue(rows[0][0]) + + @defer.inlineCallbacks + def _assert_stats_reporting(self, messages, time): + rows = yield self.db_pool.runQuery( + "SELECT reported_stream_token, reported_time FROM stats_reporting" + ) + self.assertEqual([(self.base_event + messages, time,)], rows) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index ab7625a3ca..caffce64e3 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -85,7 +85,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() - self.event_factory = hs.get_event_factory(); + self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 0c9b89d765..a658a789aa 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.types import UserID, RoomID +from tests.storage.event_injector import EventInjector from tests.utils import setup_test_homeserver @@ -36,6 +37,7 @@ class StreamStoreTestCase(unittest.TestCase): self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() + self.event_injector = EventInjector(hs) self.handlers = hs.get_handlers() self.message_handler = self.handlers.message_handler @@ -45,60 +47,20 @@ class StreamStoreTestCase(unittest.TestCase): self.room1 = RoomID.from_string("!abc123:test") self.room2 = RoomID.from_string("!xyx987:test") - self.depth = 1 - - @defer.inlineCallbacks - def inject_room_member(self, room, user, membership): - self.depth += 1 - - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"membership": membership}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - - defer.returnValue(event) - - @defer.inlineCallbacks - def inject_message(self, room, user, body): - self.depth += 1 - - builder = self.event_builder_factory.new({ - "type": EventTypes.Message, - "sender": 
user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"body": body, "msgtype": u"message"}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - @defer.inlineCallbacks def test_event_stream_get_other(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -125,17 +87,17 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_get_own(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -162,22 +124,22 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_join_leave(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Then bob leaves again. 
- yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.LEAVE ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -193,17 +155,17 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_prev_content(self): - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) - event1 = yield self.inject_room_member( + event1 = yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) start = yield self.store.get_room_events_max_id() - event2 = yield self.inject_room_member( + event2 = yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN, ) -- cgit 1.4.1 From 6d59ffe1ce9a821d50d491f97bf05950198f6f53 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 22 Sep 2015 13:47:40 +0100 Subject: Add some docstrings --- synapse/storage/__init__.py | 3 +++ synapse/storage/registration.py | 1 + 2 files changed, 4 insertions(+) (limited to 'synapse/storage/__init__.py') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index b64c90d631..340e59afcb 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -128,6 +128,9 @@ class DataStore(RoomMemberStore, RoomStore, @defer.inlineCallbacks def count_daily_users(self): + """ + Counts the number of users who used this homeserver in the last 24 hours. + """ def _count_users(txn): txn.execute( "SELECT COUNT(DISTINCT user_id) AS users" diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 6d76237658..b454dd5b3a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -292,6 +292,7 @@ class RegistrationStore(SQLBaseStore): @defer.inlineCallbacks def count_all_users(self): + """Counts all users registered on the homeserver.""" def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users") rows = self.cursor_to_dict(txn) -- cgit 1.4.1 From c85c9125627a62c73711786723be12be30d7a81e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 9 Oct 2015 15:48:31 +0100 Subject: Add basic full text search impl. 
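
As a sketch of the interface this adds (inferred from the handler code
below; the query text and room ID are illustrative), the new POST
/search endpoint expects a body shaped like:

    {
        "search_categories": {
            "room_events": {
                "constraints": [
                    {
                        "type": "fts",
                        "keys": ["content.body"],
                        "value": "raccoons"
                    },
                    {
                        "type": "exact",
                        "keys": ["room_id"],
                        "value": ["!abc123:test"]
                    }
                ]
            }
        }
    }

At most one constraint may be of type "fts".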
--- synapse/api/constants.py | 19 +++++++ synapse/handlers/__init__.py | 2 + synapse/handlers/search.py | 95 ++++++++++++++++++++++++++++++++++ synapse/rest/client/v1/room.py | 17 ++++++ synapse/storage/__init__.py | 2 + synapse/storage/_base.py | 2 +- synapse/storage/schema/delta/24/fts.py | 57 ++++++++++++++++++++ synapse/storage/search.py | 75 +++++++++++++++++++++++++++ 8 files changed, 268 insertions(+), 1 deletion(-) create mode 100644 synapse/handlers/search.py create mode 100644 synapse/storage/schema/delta/24/fts.py create mode 100644 synapse/storage/search.py (limited to 'synapse/storage/__init__.py') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 008ee64727..7c7f9ff957 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -84,3 +84,22 @@ class RoomCreationPreset(object): PRIVATE_CHAT = "private_chat" PUBLIC_CHAT = "public_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat" + + +class SearchConstraintTypes(object): + FTS = "fts" + EXACT = "exact" + PREFIX = "prefix" + SUBSTRING = "substring" + RANGE = "range" + + +class KnownRoomEventKeys(object): + CONTENT_BODY = "content.body" + CONTENT_MSGTYPE = "content.msgtype" + CONTENT_NAME = "content.name" + CONTENT_TOPIC = "content.topic" + + SENDER = "sender" + ORIGIN_SERVER_TS = "origin_server_ts" + ROOM_ID = "room_id" diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 8725c3c420..87b4d381c7 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -32,6 +32,7 @@ from .sync import SyncHandler from .auth import AuthHandler from .identity import IdentityHandler from .receipts import ReceiptsHandler +from .search import SearchHandler class Handlers(object): @@ -68,3 +69,4 @@ class Handlers(object): self.sync_handler = SyncHandler(hs) self.auth_handler = AuthHandler(hs) self.identity_handler = IdentityHandler(hs) + self.search_handler = SearchHandler(hs) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py new file mode 100644 index 0000000000..8b997fc394 --- /dev/null +++ b/synapse/handlers/search.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from twisted.internet import defer
+
+from ._base import BaseHandler
+
+from synapse.api.constants import KnownRoomEventKeys, SearchConstraintTypes
+from synapse.api.errors import SynapseError
+from synapse.events.utils import serialize_event
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+KEYS_TO_ALLOWED_CONSTRAINT_TYPES = {
+    KnownRoomEventKeys.CONTENT_BODY: [SearchConstraintTypes.FTS],
+    KnownRoomEventKeys.CONTENT_MSGTYPE: [SearchConstraintTypes.EXACT],
+    KnownRoomEventKeys.CONTENT_NAME: [SearchConstraintTypes.FTS, SearchConstraintTypes.EXACT, SearchConstraintTypes.SUBSTRING],
+    KnownRoomEventKeys.CONTENT_TOPIC: [SearchConstraintTypes.FTS],
+    KnownRoomEventKeys.SENDER: [SearchConstraintTypes.EXACT],
+    KnownRoomEventKeys.ORIGIN_SERVER_TS: [SearchConstraintTypes.RANGE],
+    KnownRoomEventKeys.ROOM_ID: [SearchConstraintTypes.EXACT],
+}
+
+
+class RoomConstraint(object):
+    def __init__(self, search_type, keys, value):
+        self.search_type = search_type
+        self.keys = keys
+        self.value = value
+
+    @classmethod
+    def from_dict(cls, d):
+        search_type = d["type"]
+        keys = d["keys"]
+
+        for key in keys:
+            if key not in KEYS_TO_ALLOWED_CONSTRAINT_TYPES:
+                raise SynapseError(400, "Unrecognized key %r" % (key,))
+
+            if search_type not in KEYS_TO_ALLOWED_CONSTRAINT_TYPES[key]:
+                raise SynapseError(400, "Disallowed constraint type %r for key %r" % (search_type, key))
+
+        return cls(search_type, keys, d["value"])
+
+
+class SearchHandler(BaseHandler):
+
+    def __init__(self, hs):
+        super(SearchHandler, self).__init__(hs)
+
+    @defer.inlineCallbacks
+    def search(self, content):
+        constraint_dicts = content["search_categories"]["room_events"]["constraints"]
+        constraints = [RoomConstraint.from_dict(c) for c in constraint_dicts]
+
+        fts = False
+        for c in constraints:
+            if c.search_type == SearchConstraintTypes.FTS:
+                if fts:
+                    raise SynapseError(400, "Only one constraint can be FTS")
+                fts = True
+
+        res = yield self.hs.get_datastore().search_msgs(constraints)
+
+        time_now = self.hs.get_clock().time_msec()
+
+        results = [
+            {
+                "rank": r["rank"],
+                "result": serialize_event(r["result"], time_now)
+            }
+            for r in res
+        ]
+
+        results.sort(key=lambda r: -r["rank"])
+
+        logger.info("returning: %r", results)
+
+        defer.returnValue(results)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 23871f161e..35bd702a43 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -529,6 +529,22 @@ class RoomTypingRestServlet(ClientV1RestServlet):
         defer.returnValue((200, {}))
 
 
+class SearchRestServlet(ClientV1RestServlet):
+    PATTERN = client_path_pattern(
+        "/search$"
+    )
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        auth_user, _ = yield self.auth.get_user_by_req(request)
+
+        content = _parse_json(request)
+
+        results = yield self.handlers.search_handler.search(content)
+
+        defer.returnValue((200, results))
+
+
 def _parse_json(request):
     try:
         content = json.loads(request.content.read())
@@ -585,3 +601,4 @@ def register_servlets(hs, http_server):
     RoomInitialSyncRestServlet(hs).register(http_server)
     RoomRedactEventRestServlet(hs).register(http_server)
     RoomTypingRestServlet(hs).register(http_server)
+    SearchRestServlet(hs).register(http_server)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 340e59afcb..5f91ef77c0 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -40,6 +40,7 @@ from .filtering import FilteringStore
 from .end_to_end_keys import EndToEndKeyStore
 
 from .receipts import ReceiptsStore
+from .search import SearchStore import fnmatch @@ -79,6 +80,7 @@ class DataStore(RoomMemberStore, RoomStore, EventsStore, ReceiptsStore, EndToEndKeyStore, + SearchStore, ): def __init__(self, hs): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 693784ad38..218e708054 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -519,7 +519,7 @@ class SQLBaseStore(object): allow_none=False, desc="_simple_select_one_onecol"): """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it." + return a single row, returning a single column from it. Args: table : string giving the table name diff --git a/synapse/storage/schema/delta/24/fts.py b/synapse/storage/schema/delta/24/fts.py new file mode 100644 index 0000000000..5680332758 --- /dev/null +++ b/synapse/storage/schema/delta/24/fts.py @@ -0,0 +1,57 @@ +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage import get_statements +from synapse.storage.engines import PostgresEngine + +logger = logging.getLogger(__name__) + + +POSTGRES_SQL = """ +CREATE TABLE event_search ( + event_id TEXT, + room_id TEXT, + key TEXT, + vector tsvector +); + +INSERT INTO event_search SELECT + event_id, room_id, 'content.body', + to_tsvector('english', json::json->'content'->>'body') + FROM events NATURAL JOIN event_json WHERE type = 'm.room.message'; + +INSERT INTO event_search SELECT + event_id, room_id, 'content.name', + to_tsvector('english', json::json->'content'->>'name') + FROM events NATURAL JOIN event_json WHERE type = 'm.room.name'; + +INSERT INTO event_search SELECT + event_id, room_id, 'content.topic', + to_tsvector('english', json::json->'content'->>'topic') + FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic'; + + +CREATE INDEX event_search_idx ON event_search USING gin(vector); +""" + + +def run_upgrade(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + # We only support FTS for postgres currently. + return + + for statement in get_statements(POSTGRES_SQL.splitlines()): + cur.execute(statement) diff --git a/synapse/storage/search.py b/synapse/storage/search.py new file mode 100644 index 0000000000..eea4477765 --- /dev/null +++ b/synapse/storage/search.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
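+
+# A rough sketch (for orientation; Postgres only) of the query that
+# search_msgs() below ends up running for a single FTS constraint:
+#
+#   SELECT ts_rank_cd(vector, query) AS rank, event_id
+#   FROM plainto_tsquery('english', ?) AS query, event_search
+#   WHERE vector @@ query AND (key = ?)
+#   ORDER BY rank DESC
+#
+# bound to the FTS search string followed by the constrained key names.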
+
+from twisted.internet import defer
+
+from ._base import SQLBaseStore
+from synapse.api.constants import KnownRoomEventKeys, SearchConstraintTypes
+
+
+class SearchStore(SQLBaseStore):
+    @defer.inlineCallbacks
+    def search_msgs(self, constraints):
+        clauses = []
+        args = []
+        fts = None
+
+        for c in constraints:
+            local_clauses = []
+            if c.search_type == SearchConstraintTypes.FTS:
+                fts = c.value
+                for key in c.keys:
+                    local_clauses.append("key = ?")
+                    args.append(key)
+            elif c.search_type == SearchConstraintTypes.EXACT:
+                for key in c.keys:
+                    if key == KnownRoomEventKeys.ROOM_ID:
+                        for value in c.value:
+                            local_clauses.append("room_id = ?")
+                            args.append(value)
+            clauses.append(
+                "(%s)" % (" OR ".join(local_clauses),)
+            )
+
+        sql = (
+            "SELECT ts_rank_cd(vector, query) AS rank, event_id"
+            " FROM plainto_tsquery('english', ?) as query, event_search"
+            " WHERE vector @@ query"
+        )
+
+        for clause in clauses:
+            sql += " AND " + clause
+
+        sql += " ORDER BY rank DESC"
+
+        results = yield self._execute(
+            "search_msgs", self.cursor_to_dict, sql, *([fts] + args)
+        )
+
+        events = yield self._get_events([r["event_id"] for r in results])
+
+        event_map = {
+            ev.event_id: ev
+            for ev in events
+        }
+
+        defer.returnValue([
+            {
+                "rank": r["rank"],
+                "result": event_map[r["event_id"]]
+            }
+            for r in results
+            if r["event_id"] in event_map
+        ])
-- 
cgit 1.4.1


From 40b6a5aad1309fed9d1e32be387798dd46b2cf4f Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 13 Oct 2015 11:38:48 +0100
Subject: Split out the schema preparation and update logic into its own module

---
 synapse/storage/__init__.py         | 378 +---------------------------------
 synapse/storage/_schema_prepare.py  | 395 ++++++++++++++++++++++++++++++++++++
 synapse/storage/engines/postgres.py |   2 +-
 synapse/storage/engines/sqlite3.py  |   4 +-
 4 files changed, 402 insertions(+), 377 deletions(-)
 create mode 100644 synapse/storage/_schema_prepare.py
(limited to 'synapse/storage/__init__.py')

diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 340e59afcb..4be629bff8 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -41,23 +41,16 @@ from .end_to_end_keys import EndToEndKeyStore
 
 from .receipts import ReceiptsStore
 
+from ._schema_prepare import UpgradeDatabaseException
+
+__all__ = ["UpgradeDatabaseException"]
 
-import fnmatch
-import imp
 import logging
-import os
-import re
 
 
 logger = logging.getLogger(__name__)
 
 
-# Remember to update this number every time a change is made to database
-# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 24
-
-dir_path = os.path.abspath(os.path.dirname(__file__))
-
 # Number of msec of granularity to store the user IP 'last seen' time. Smaller
 # times give more inserts into the database even for readonly API hits
 # 120 seconds == 2 minutes
@@ -158,371 +151,6 @@ class DataStore(RoomMemberStore, RoomStore,
         )
 
 
-def read_schema(path):
-    """ Read the named database schema.
-
-    Args:
-        path: Path of the database schema.
-    Returns:
-        A string containing the database schema.
-    """
-    with open(path) as schema_file:
-        return schema_file.read()
-
-
-class PrepareDatabaseException(Exception):
-    pass
-
-
-class UpgradeDatabaseException(PrepareDatabaseException):
-    pass
-
-
-def prepare_database(db_conn, database_engine):
-    """Prepares a database for usage. Will either create all necessary tables
-    or upgrade from an older schema version.
- """ - try: - cur = db_conn.cursor() - version_info = _get_or_create_schema_state(cur, database_engine) - - if version_info: - user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine - ) - else: - _setup_new_database(cur, database_engine) - - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) - - cur.close() - db_conn.commit() - except: - db_conn.rollback() - raise - - -def _setup_new_database(cur, database_engine): - """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas. - - The "full_schemas" directory has subdirectories named after versions. This - function searches for the highest version less than or equal to - `SCHEMA_VERSION` and executes all .sql files in that directory. - - The function will then apply all deltas for all versions after the base - version. - - Example directory structure: - - schema/ - delta/ - ... - full_schemas/ - 3/ - test.sql - ... - 11/ - foo.sql - bar.sql - ... - - In the example foo.sql and bar.sql would be run, and then any delta files - for versions strictly greater than 11. - """ - current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) - - valid_dirs = [] - pattern = re.compile(r"^\d+(\.sql)?$") - for filename in directory_entries: - match = pattern.match(filename) - abs_path = os.path.join(current_dir, filename) - if match and os.path.isdir(abs_path): - ver = int(match.group(0)) - if ver <= SCHEMA_VERSION: - valid_dirs.append((ver, abs_path)) - else: - logger.warn("Unexpected entry in 'full_schemas': %s", filename) - - if not valid_dirs: - raise PrepareDatabaseException( - "Could not find a suitable base set of full schemas" - ) - - max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) - - logger.debug("Initialising schema v%d", max_current_ver) - - directory_entries = os.listdir(sql_dir) - - for filename in fnmatch.filter(directory_entries, "*.sql"): - sql_loc = os.path.join(sql_dir, filename) - logger.debug("Applying schema %s", sql_loc) - executescript(cur, sql_loc) - - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)" - ), - (max_current_ver, False,) - ) - - _upgrade_existing_database( - cur, - current_version=max_current_ver, - applied_delta_files=[], - upgraded=False, - database_engine=database_engine, - ) - - -def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): - """Upgrades an existing database. - - Delta files can either be SQL stored in *.sql files, or python modules - in *.py. - - There can be multiple delta files per version. Synapse will keep track of - which delta files have been applied, and will apply any that haven't been - even if there has been no version bump. This is useful for development - where orthogonal schema changes may happen on separate branches. - - Different delta files for the same version *must* be orthogonal and give - the same result when applied in any order. No guarantees are made on the - order of execution of these scripts. - - This is a no-op of current_version == SCHEMA_VERSION. - - Example directory structure: - - schema/ - delta/ - 11/ - foo.sql - ... - 12/ - foo.sql - bar.py - ... - full_schemas/ - ... - - In the example, if current_version is 11, then foo.sql will be run if and - only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in - some arbitrary order. 
- - Args: - cur (Cursor) - current_version (int): The current version of the schema. - applied_delta_files (list): A list of deltas that have already been - applied. - upgraded (bool): Whether the current version was generated by having - applied deltas or from full schema file. If `True` the function - will never apply delta files for the given `current_version`, since - the current_version wasn't generated by applying those delta files. - """ - - if current_version > SCHEMA_VERSION: - raise ValueError( - "Cannot use this database as it is too " + - "new for the server to understand" - ) - - start_ver = current_version - if not upgraded: - start_ver += 1 - - logger.debug("applied_delta_files: %s", applied_delta_files) - - for v in range(start_ver, SCHEMA_VERSION + 1): - logger.debug("Upgrading schema to v%d", v) - - delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) - - try: - directory_entries = os.listdir(delta_dir) - except OSError: - logger.exception("Could not open delta dir for version %d", v) - raise UpgradeDatabaseException( - "Could not open delta dir for version %d" % (v,) - ) - - directory_entries.sort() - for file_name in directory_entries: - relative_path = os.path.join(str(v), file_name) - logger.debug("Found file: %s", relative_path) - if relative_path in applied_delta_files: - continue - - absolute_path = os.path.join( - dir_path, "schema", "delta", relative_path, - ) - root_name, ext = os.path.splitext(file_name) - if ext == ".py": - # This is a python upgrade module. We need to import into some - # package and then execute its `run_upgrade` function. - module_name = "synapse.storage.v%d_%s" % ( - v, root_name - ) - with open(absolute_path) as python_file: - module = imp.load_source( - module_name, absolute_path, python_file - ) - logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) - elif ext == ".pyc": - # Sometimes .pyc files turn up anyway even though we've - # disabled their generation; e.g. from distribution package - # installers. Silently skip it - pass - elif ext == ".sql": - # A plain old .sql file, just read and execute it - logger.debug("Applying schema %s", relative_path) - executescript(cur, absolute_path) - else: - # Not a valid delta file. - logger.warn( - "Found directory entry that did not end in .py or" - " .sql: %s", - relative_path, - ) - continue - - # Mark as done. - cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", - ), - (v, relative_path) - ) - - cur.execute("DELETE FROM schema_version") - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)", - ), - (v, True) - ) - - -def get_statements(f): - statement_buffer = "" - in_comment = False # If we're in a /* ... */ style comment - - for line in f: - line = line.strip() - - if in_comment: - # Check if this line contains an end to the comment - comments = line.split("*/", 1) - if len(comments) == 1: - continue - line = comments[1] - in_comment = False - - # Remove inline block comments - line = re.sub(r"/\*.*\*/", " ", line) - - # Does this line start a comment? - comments = line.split("/*", 1) - if len(comments) > 1: - line = comments[0] - in_comment = True - - # Deal with line comments - line = line.split("--", 1)[0] - line = line.split("//", 1)[0] - - # Find *all* semicolons. We need to treat first and last entry - # specially. 
- statements = line.split(";") - - # We must prepend statement_buffer to the first statement - first_statement = "%s %s" % ( - statement_buffer.strip(), - statements[0].strip() - ) - statements[0] = first_statement - - # Every entry, except the last, is a full statement - for statement in statements[:-1]: - yield statement.strip() - - # The last entry did *not* end in a semicolon, so we store it for the - # next semicolon we find - statement_buffer = statements[-1].strip() - - -def executescript(txn, schema_path): - with open(schema_path, 'r') as f: - for statement in get_statements(f): - txn.execute(statement) - - -def _get_or_create_schema_state(txn, database_engine): - # Bluntly try creating the schema_version tables. - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - executescript(txn, schema_path) - - txn.execute("SELECT version, upgraded FROM schema_version") - row = txn.fetchone() - current_version = int(row[0]) if row else None - upgraded = bool(row[1]) if row else None - - if current_version: - txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), - (current_version,) - ) - applied_deltas = [d for d, in txn.fetchall()] - return current_version, applied_deltas, upgraded - - return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) - - def are_all_users_on_domain(txn, database_engine, domain): sql = database_engine.convert_param_style( "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" diff --git a/synapse/storage/_schema_prepare.py b/synapse/storage/_schema_prepare.py new file mode 100644 index 0000000000..1ddf55be4d --- /dev/null +++ b/synapse/storage/_schema_prepare.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fnmatch +import imp +import logging +import os +import re + + +logger = logging.getLogger(__name__) + + +# Remember to update this number every time a change is made to database +# schema files, so the users will be informed on server restarts. 
+SCHEMA_VERSION = 24 + +dir_path = os.path.abspath(os.path.dirname(__file__)) + + +def read_schema(path): + """ Read the named database schema. + + Args: + path: Path of the database schema. + Returns: + A string containing the database schema. + """ + with open(path) as schema_file: + return schema_file.read() + + +class PrepareDatabaseException(Exception): + pass + + +class UpgradeDatabaseException(PrepareDatabaseException): + pass + + +def prepare_database(db_conn, database_engine): + """Prepares a database for usage. Will either create all necessary tables + or upgrade from an older schema version. + """ + try: + cur = db_conn.cursor() + version_info = _get_or_create_schema_state(cur, database_engine) + + if version_info: + user_version, delta_files, upgraded = version_info + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine + ) + else: + _setup_new_database(cur, database_engine) + + # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + + cur.close() + db_conn.commit() + except: + db_conn.rollback() + raise + + +def _setup_new_database(cur, database_engine): + """Sets up the database by finding a base set of "full schemas" and then + applying any necessary deltas. + + The "full_schemas" directory has subdirectories named after versions. This + function searches for the highest version less than or equal to + `SCHEMA_VERSION` and executes all .sql files in that directory. + + The function will then apply all deltas for all versions after the base + version. + + Example directory structure: + + schema/ + delta/ + ... + full_schemas/ + 3/ + test.sql + ... + 11/ + foo.sql + bar.sql + ... + + In the example foo.sql and bar.sql would be run, and then any delta files + for versions strictly greater than 11. + """ + current_dir = os.path.join(dir_path, "schema", "full_schemas") + directory_entries = os.listdir(current_dir) + + valid_dirs = [] + pattern = re.compile(r"^\d+(\.sql)?$") + for filename in directory_entries: + match = pattern.match(filename) + abs_path = os.path.join(current_dir, filename) + if match and os.path.isdir(abs_path): + ver = int(match.group(0)) + if ver <= SCHEMA_VERSION: + valid_dirs.append((ver, abs_path)) + else: + logger.warn("Unexpected entry in 'full_schemas': %s", filename) + + if not valid_dirs: + raise PrepareDatabaseException( + "Could not find a suitable base set of full schemas" + ) + + max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) + + logger.debug("Initialising schema v%d", max_current_ver) + + directory_entries = os.listdir(sql_dir) + + for filename in fnmatch.filter(directory_entries, "*.sql"): + sql_loc = os.path.join(sql_dir, filename) + logger.debug("Applying schema %s", sql_loc) + executescript(cur, sql_loc) + + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), + (max_current_ver, False,) + ) + + _upgrade_existing_database( + cur, + current_version=max_current_ver, + applied_delta_files=[], + upgraded=False, + database_engine=database_engine, + ) + + +def _upgrade_existing_database(cur, current_version, applied_delta_files, + upgraded, database_engine): + """Upgrades an existing database. + + Delta files can either be SQL stored in *.sql files, or python modules + in *.py. + + There can be multiple delta files per version. Synapse will keep track of + which delta files have been applied, and will apply any that haven't been + even if there has been no version bump. 
This is useful for development + where orthogonal schema changes may happen on separate branches. + + Different delta files for the same version *must* be orthogonal and give + the same result when applied in any order. No guarantees are made on the + order of execution of these scripts. + + This is a no-op of current_version == SCHEMA_VERSION. + + Example directory structure: + + schema/ + delta/ + 11/ + foo.sql + ... + 12/ + foo.sql + bar.py + ... + full_schemas/ + ... + + In the example, if current_version is 11, then foo.sql will be run if and + only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in + some arbitrary order. + + Args: + cur (Cursor) + current_version (int): The current version of the schema. + applied_delta_files (list): A list of deltas that have already been + applied. + upgraded (bool): Whether the current version was generated by having + applied deltas or from full schema file. If `True` the function + will never apply delta files for the given `current_version`, since + the current_version wasn't generated by applying those delta files. + """ + + if current_version > SCHEMA_VERSION: + raise ValueError( + "Cannot use this database as it is too " + + "new for the server to understand" + ) + + start_ver = current_version + if not upgraded: + start_ver += 1 + + logger.debug("applied_delta_files: %s", applied_delta_files) + + for v in range(start_ver, SCHEMA_VERSION + 1): + logger.debug("Upgrading schema to v%d", v) + + delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + + try: + directory_entries = os.listdir(delta_dir) + except OSError: + logger.exception("Could not open delta dir for version %d", v) + raise UpgradeDatabaseException( + "Could not open delta dir for version %d" % (v,) + ) + + directory_entries.sort() + for file_name in directory_entries: + relative_path = os.path.join(str(v), file_name) + logger.debug("Found file: %s", relative_path) + if relative_path in applied_delta_files: + continue + + absolute_path = os.path.join( + dir_path, "schema", "delta", relative_path, + ) + root_name, ext = os.path.splitext(file_name) + if ext == ".py": + # This is a python upgrade module. We need to import into some + # package and then execute its `run_upgrade` function. + module_name = "synapse.storage.v%d_%s" % ( + v, root_name + ) + with open(absolute_path) as python_file: + module = imp.load_source( + module_name, absolute_path, python_file + ) + logger.debug("Running script %s", relative_path) + module.run_upgrade(cur, database_engine) + elif ext == ".pyc": + # Sometimes .pyc files turn up anyway even though we've + # disabled their generation; e.g. from distribution package + # installers. Silently skip it + pass + elif ext == ".sql": + # A plain old .sql file, just read and execute it + logger.debug("Applying schema %s", relative_path) + executescript(cur, absolute_path) + else: + # Not a valid delta file. + logger.warn( + "Found directory entry that did not end in .py or" + " .sql: %s", + relative_path, + ) + continue + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)", + ), + (v, relative_path) + ) + + cur.execute("DELETE FROM schema_version") + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)", + ), + (v, True) + ) + + +def get_statements(f): + statement_buffer = "" + in_comment = False # If we're in a /* ... 
*/ style comment + + for line in f: + line = line.strip() + + if in_comment: + # Check if this line contains an end to the comment + comments = line.split("*/", 1) + if len(comments) == 1: + continue + line = comments[1] + in_comment = False + + # Remove inline block comments + line = re.sub(r"/\*.*\*/", " ", line) + + # Does this line start a comment? + comments = line.split("/*", 1) + if len(comments) > 1: + line = comments[0] + in_comment = True + + # Deal with line comments + line = line.split("--", 1)[0] + line = line.split("//", 1)[0] + + # Find *all* semicolons. We need to treat first and last entry + # specially. + statements = line.split(";") + + # We must prepend statement_buffer to the first statement + first_statement = "%s %s" % ( + statement_buffer.strip(), + statements[0].strip() + ) + statements[0] = first_statement + + # Every entry, except the last, is a full statement + for statement in statements[:-1]: + yield statement.strip() + + # The last entry did *not* end in a semicolon, so we store it for the + # next semicolon we find + statement_buffer = statements[-1].strip() + + +def executescript(txn, schema_path): + with open(schema_path, 'r') as f: + for statement in get_statements(f): + txn.execute(statement) + + +def _get_or_create_schema_state(txn, database_engine): + # Bluntly try creating the schema_version tables. + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + executescript(txn, schema_path) + + txn.execute("SELECT version, upgraded FROM schema_version") + row = txn.fetchone() + current_version = int(row[0]) if row else None + upgraded = bool(row[1]) if row else None + + if current_version: + txn.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_schema_deltas WHERE version >= ?" + ), + (current_version,) + ) + applied_deltas = [d for d, in txn.fetchall()] + return current_version, applied_deltas, upgraded + + return None + + +def prepare_sqlite3_database(db_conn): + """This function should be called before `prepare_database` on sqlite3 + databases. + + Since we changed the way we store the current schema version and handle + updates to schemas, we need a way to upgrade from the old method to the + new. This only affects sqlite databases since they were the only ones + supported at the time. + """ + with db_conn: + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + create_schema = read_schema(schema_path) + db_conn.executescript(create_schema) + + c = db_conn.execute("SELECT * FROM schema_version") + rows = c.fetchall() + c.close() + + if not rows: + c = db_conn.execute("PRAGMA user_version") + row = c.fetchone() + c.close() + + if row and row[0]: + db_conn.execute( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + (row[0], False) + ) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 4a855ffd56..949396044e 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.storage import prepare_database +from synapse.storage._schema_prepare import prepare_database from ._base import IncorrectDatabaseSetup diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index d18e2808d1..a66815ef2d 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import prepare_database, prepare_sqlite3_database +from synapse.storage._schema_prepare import ( + prepare_database, prepare_sqlite3_database +) class Sqlite3Engine(object): -- cgit 1.4.1 From ec398af41c4d276abb02279efbcbb0aa08a4cbc8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Oct 2015 11:41:04 +0100 Subject: Expose error more nicely --- synapse/app/homeserver.py | 5 +- synapse/storage/__init__.py | 3 - synapse/storage/_schema_prepare.py | 395 ------------------------------------ synapse/storage/engines/postgres.py | 2 +- synapse/storage/engines/sqlite3.py | 2 +- synapse/storage/schema_prepare.py | 395 ++++++++++++++++++++++++++++++++++++ tests/utils.py | 2 +- 7 files changed, 400 insertions(+), 404 deletions(-) delete mode 100644 synapse/storage/_schema_prepare.py create mode 100644 synapse/storage/schema_prepare.py (limited to 'synapse/storage/__init__.py') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 190b03e2f7..b284d07cf0 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -35,9 +35,8 @@ if __name__ == '__main__': from synapse.storage.engines import create_engine, IncorrectDatabaseSetup -from synapse.storage import ( - are_all_users_on_domain, UpgradeDatabaseException, -) +from synapse.storage import are_all_users_on_domain +from synapse.storage.schema_prepare import UpgradeDatabaseException from synapse.server import HomeServer diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4be629bff8..48a0633746 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -41,9 +41,6 @@ from .end_to_end_keys import EndToEndKeyStore from .receipts import ReceiptsStore -from ._schema_prepare import UpgradeDatabaseException - -__all__ = [UpgradeDatabaseException] import logging diff --git a/synapse/storage/_schema_prepare.py b/synapse/storage/_schema_prepare.py deleted file mode 100644 index 1ddf55be4d..0000000000 --- a/synapse/storage/_schema_prepare.py +++ /dev/null @@ -1,395 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import fnmatch -import imp -import logging -import os -import re - - -logger = logging.getLogger(__name__) - - -# Remember to update this number every time a change is made to database -# schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 24 - -dir_path = os.path.abspath(os.path.dirname(__file__)) - - -def read_schema(path): - """ Read the named database schema. 
- - Args: - path: Path of the database schema. - Returns: - A string containing the database schema. - """ - with open(path) as schema_file: - return schema_file.read() - - -class PrepareDatabaseException(Exception): - pass - - -class UpgradeDatabaseException(PrepareDatabaseException): - pass - - -def prepare_database(db_conn, database_engine): - """Prepares a database for usage. Will either create all necessary tables - or upgrade from an older schema version. - """ - try: - cur = db_conn.cursor() - version_info = _get_or_create_schema_state(cur, database_engine) - - if version_info: - user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine - ) - else: - _setup_new_database(cur, database_engine) - - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) - - cur.close() - db_conn.commit() - except: - db_conn.rollback() - raise - - -def _setup_new_database(cur, database_engine): - """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas. - - The "full_schemas" directory has subdirectories named after versions. This - function searches for the highest version less than or equal to - `SCHEMA_VERSION` and executes all .sql files in that directory. - - The function will then apply all deltas for all versions after the base - version. - - Example directory structure: - - schema/ - delta/ - ... - full_schemas/ - 3/ - test.sql - ... - 11/ - foo.sql - bar.sql - ... - - In the example foo.sql and bar.sql would be run, and then any delta files - for versions strictly greater than 11. - """ - current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) - - valid_dirs = [] - pattern = re.compile(r"^\d+(\.sql)?$") - for filename in directory_entries: - match = pattern.match(filename) - abs_path = os.path.join(current_dir, filename) - if match and os.path.isdir(abs_path): - ver = int(match.group(0)) - if ver <= SCHEMA_VERSION: - valid_dirs.append((ver, abs_path)) - else: - logger.warn("Unexpected entry in 'full_schemas': %s", filename) - - if not valid_dirs: - raise PrepareDatabaseException( - "Could not find a suitable base set of full schemas" - ) - - max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) - - logger.debug("Initialising schema v%d", max_current_ver) - - directory_entries = os.listdir(sql_dir) - - for filename in fnmatch.filter(directory_entries, "*.sql"): - sql_loc = os.path.join(sql_dir, filename) - logger.debug("Applying schema %s", sql_loc) - executescript(cur, sql_loc) - - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)" - ), - (max_current_ver, False,) - ) - - _upgrade_existing_database( - cur, - current_version=max_current_ver, - applied_delta_files=[], - upgraded=False, - database_engine=database_engine, - ) - - -def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): - """Upgrades an existing database. - - Delta files can either be SQL stored in *.sql files, or python modules - in *.py. - - There can be multiple delta files per version. Synapse will keep track of - which delta files have been applied, and will apply any that haven't been - even if there has been no version bump. This is useful for development - where orthogonal schema changes may happen on separate branches. 
- - Different delta files for the same version *must* be orthogonal and give - the same result when applied in any order. No guarantees are made on the - order of execution of these scripts. - - This is a no-op of current_version == SCHEMA_VERSION. - - Example directory structure: - - schema/ - delta/ - 11/ - foo.sql - ... - 12/ - foo.sql - bar.py - ... - full_schemas/ - ... - - In the example, if current_version is 11, then foo.sql will be run if and - only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in - some arbitrary order. - - Args: - cur (Cursor) - current_version (int): The current version of the schema. - applied_delta_files (list): A list of deltas that have already been - applied. - upgraded (bool): Whether the current version was generated by having - applied deltas or from full schema file. If `True` the function - will never apply delta files for the given `current_version`, since - the current_version wasn't generated by applying those delta files. - """ - - if current_version > SCHEMA_VERSION: - raise ValueError( - "Cannot use this database as it is too " + - "new for the server to understand" - ) - - start_ver = current_version - if not upgraded: - start_ver += 1 - - logger.debug("applied_delta_files: %s", applied_delta_files) - - for v in range(start_ver, SCHEMA_VERSION + 1): - logger.debug("Upgrading schema to v%d", v) - - delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) - - try: - directory_entries = os.listdir(delta_dir) - except OSError: - logger.exception("Could not open delta dir for version %d", v) - raise UpgradeDatabaseException( - "Could not open delta dir for version %d" % (v,) - ) - - directory_entries.sort() - for file_name in directory_entries: - relative_path = os.path.join(str(v), file_name) - logger.debug("Found file: %s", relative_path) - if relative_path in applied_delta_files: - continue - - absolute_path = os.path.join( - dir_path, "schema", "delta", relative_path, - ) - root_name, ext = os.path.splitext(file_name) - if ext == ".py": - # This is a python upgrade module. We need to import into some - # package and then execute its `run_upgrade` function. - module_name = "synapse.storage.v%d_%s" % ( - v, root_name - ) - with open(absolute_path) as python_file: - module = imp.load_source( - module_name, absolute_path, python_file - ) - logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) - elif ext == ".pyc": - # Sometimes .pyc files turn up anyway even though we've - # disabled their generation; e.g. from distribution package - # installers. Silently skip it - pass - elif ext == ".sql": - # A plain old .sql file, just read and execute it - logger.debug("Applying schema %s", relative_path) - executescript(cur, absolute_path) - else: - # Not a valid delta file. - logger.warn( - "Found directory entry that did not end in .py or" - " .sql: %s", - relative_path, - ) - continue - - # Mark as done. - cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", - ), - (v, relative_path) - ) - - cur.execute("DELETE FROM schema_version") - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)", - ), - (v, True) - ) - - -def get_statements(f): - statement_buffer = "" - in_comment = False # If we're in a /* ... 
*/ style comment - - for line in f: - line = line.strip() - - if in_comment: - # Check if this line contains an end to the comment - comments = line.split("*/", 1) - if len(comments) == 1: - continue - line = comments[1] - in_comment = False - - # Remove inline block comments - line = re.sub(r"/\*.*\*/", " ", line) - - # Does this line start a comment? - comments = line.split("/*", 1) - if len(comments) > 1: - line = comments[0] - in_comment = True - - # Deal with line comments - line = line.split("--", 1)[0] - line = line.split("//", 1)[0] - - # Find *all* semicolons. We need to treat first and last entry - # specially. - statements = line.split(";") - - # We must prepend statement_buffer to the first statement - first_statement = "%s %s" % ( - statement_buffer.strip(), - statements[0].strip() - ) - statements[0] = first_statement - - # Every entry, except the last, is a full statement - for statement in statements[:-1]: - yield statement.strip() - - # The last entry did *not* end in a semicolon, so we store it for the - # next semicolon we find - statement_buffer = statements[-1].strip() - - -def executescript(txn, schema_path): - with open(schema_path, 'r') as f: - for statement in get_statements(f): - txn.execute(statement) - - -def _get_or_create_schema_state(txn, database_engine): - # Bluntly try creating the schema_version tables. - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - executescript(txn, schema_path) - - txn.execute("SELECT version, upgraded FROM schema_version") - row = txn.fetchone() - current_version = int(row[0]) if row else None - upgraded = bool(row[1]) if row else None - - if current_version: - txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), - (current_version,) - ) - applied_deltas = [d for d, in txn.fetchall()] - return current_version, applied_deltas, upgraded - - return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 949396044e..7e45dabf4c 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.storage._schema_prepare import prepare_database +from synapse.storage.schema_prepare import prepare_database from ._base import IncorrectDatabaseSetup diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index a66815ef2d..0eeaa45d19 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage._schema_prepare import ( +from synapse.storage.schema_prepare import ( prepare_database, prepare_sqlite3_database ) diff --git a/synapse/storage/schema_prepare.py b/synapse/storage/schema_prepare.py new file mode 100644 index 0000000000..1ddf55be4d --- /dev/null +++ b/synapse/storage/schema_prepare.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fnmatch +import imp +import logging +import os +import re + + +logger = logging.getLogger(__name__) + + +# Remember to update this number every time a change is made to database +# schema files, so the users will be informed on server restarts. +SCHEMA_VERSION = 24 + +dir_path = os.path.abspath(os.path.dirname(__file__)) + + +def read_schema(path): + """ Read the named database schema. + + Args: + path: Path of the database schema. + Returns: + A string containing the database schema. + """ + with open(path) as schema_file: + return schema_file.read() + + +class PrepareDatabaseException(Exception): + pass + + +class UpgradeDatabaseException(PrepareDatabaseException): + pass + + +def prepare_database(db_conn, database_engine): + """Prepares a database for usage. Will either create all necessary tables + or upgrade from an older schema version. + """ + try: + cur = db_conn.cursor() + version_info = _get_or_create_schema_state(cur, database_engine) + + if version_info: + user_version, delta_files, upgraded = version_info + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine + ) + else: + _setup_new_database(cur, database_engine) + + # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + + cur.close() + db_conn.commit() + except: + db_conn.rollback() + raise + + +def _setup_new_database(cur, database_engine): + """Sets up the database by finding a base set of "full schemas" and then + applying any necessary deltas. + + The "full_schemas" directory has subdirectories named after versions. This + function searches for the highest version less than or equal to + `SCHEMA_VERSION` and executes all .sql files in that directory. + + The function will then apply all deltas for all versions after the base + version. + + Example directory structure: + + schema/ + delta/ + ... + full_schemas/ + 3/ + test.sql + ... + 11/ + foo.sql + bar.sql + ... + + In the example foo.sql and bar.sql would be run, and then any delta files + for versions strictly greater than 11. 
+    """
+    current_dir = os.path.join(dir_path, "schema", "full_schemas")
+    directory_entries = os.listdir(current_dir)
+
+    valid_dirs = []
+    pattern = re.compile(r"^\d+(\.sql)?$")
+    for filename in directory_entries:
+        match = pattern.match(filename)
+        abs_path = os.path.join(current_dir, filename)
+        if match and os.path.isdir(abs_path):
+            ver = int(match.group(0))
+            if ver <= SCHEMA_VERSION:
+                valid_dirs.append((ver, abs_path))
+        else:
+            logger.warn("Unexpected entry in 'full_schemas': %s", filename)
+
+    if not valid_dirs:
+        raise PrepareDatabaseException(
+            "Could not find a suitable base set of full schemas"
+        )
+
+    max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0])
+
+    logger.debug("Initialising schema v%d", max_current_ver)
+
+    directory_entries = os.listdir(sql_dir)
+
+    for filename in fnmatch.filter(directory_entries, "*.sql"):
+        sql_loc = os.path.join(sql_dir, filename)
+        logger.debug("Applying schema %s", sql_loc)
+        executescript(cur, sql_loc)
+
+    cur.execute(
+        database_engine.convert_param_style(
+            "INSERT INTO schema_version (version, upgraded)"
+            " VALUES (?,?)"
+        ),
+        (max_current_ver, False,)
+    )
+
+    _upgrade_existing_database(
+        cur,
+        current_version=max_current_ver,
+        applied_delta_files=[],
+        upgraded=False,
+        database_engine=database_engine,
+    )
+
+
+def _upgrade_existing_database(cur, current_version, applied_delta_files,
+                               upgraded, database_engine):
+    """Upgrades an existing database.
+
+    Delta files can either be SQL stored in *.sql files, or python modules
+    in *.py.
+
+    There can be multiple delta files per version. Synapse will keep track of
+    which delta files have been applied, and will apply any that haven't been
+    even if there has been no version bump. This is useful for development
+    where orthogonal schema changes may happen on separate branches.
+
+    Different delta files for the same version *must* be orthogonal and give
+    the same result when applied in any order. No guarantees are made on the
+    order of execution of these scripts.
+
+    This is a no-op if current_version == SCHEMA_VERSION.
+
+    Example directory structure:
+
+        schema/
+            delta/
+                11/
+                    foo.sql
+                    ...
+                12/
+                    foo.sql
+                    bar.py
+                ...
+            full_schemas/
+                ...
+
+    In the example, if current_version is 11, then foo.sql will be run if and
+    only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in
+    some arbitrary order.
+
+    Args:
+        cur (Cursor)
+        current_version (int): The current version of the schema.
+        applied_delta_files (list): A list of deltas that have already been
+            applied.
+        upgraded (bool): Whether the current version was generated by having
+            applied deltas or from a full schema file. If `False` the function
+            will never apply delta files for the given `current_version`, since
+            the current_version wasn't generated by applying those delta files.
+ """ + + if current_version > SCHEMA_VERSION: + raise ValueError( + "Cannot use this database as it is too " + + "new for the server to understand" + ) + + start_ver = current_version + if not upgraded: + start_ver += 1 + + logger.debug("applied_delta_files: %s", applied_delta_files) + + for v in range(start_ver, SCHEMA_VERSION + 1): + logger.debug("Upgrading schema to v%d", v) + + delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + + try: + directory_entries = os.listdir(delta_dir) + except OSError: + logger.exception("Could not open delta dir for version %d", v) + raise UpgradeDatabaseException( + "Could not open delta dir for version %d" % (v,) + ) + + directory_entries.sort() + for file_name in directory_entries: + relative_path = os.path.join(str(v), file_name) + logger.debug("Found file: %s", relative_path) + if relative_path in applied_delta_files: + continue + + absolute_path = os.path.join( + dir_path, "schema", "delta", relative_path, + ) + root_name, ext = os.path.splitext(file_name) + if ext == ".py": + # This is a python upgrade module. We need to import into some + # package and then execute its `run_upgrade` function. + module_name = "synapse.storage.v%d_%s" % ( + v, root_name + ) + with open(absolute_path) as python_file: + module = imp.load_source( + module_name, absolute_path, python_file + ) + logger.debug("Running script %s", relative_path) + module.run_upgrade(cur, database_engine) + elif ext == ".pyc": + # Sometimes .pyc files turn up anyway even though we've + # disabled their generation; e.g. from distribution package + # installers. Silently skip it + pass + elif ext == ".sql": + # A plain old .sql file, just read and execute it + logger.debug("Applying schema %s", relative_path) + executescript(cur, absolute_path) + else: + # Not a valid delta file. + logger.warn( + "Found directory entry that did not end in .py or" + " .sql: %s", + relative_path, + ) + continue + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)", + ), + (v, relative_path) + ) + + cur.execute("DELETE FROM schema_version") + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)", + ), + (v, True) + ) + + +def get_statements(f): + statement_buffer = "" + in_comment = False # If we're in a /* ... */ style comment + + for line in f: + line = line.strip() + + if in_comment: + # Check if this line contains an end to the comment + comments = line.split("*/", 1) + if len(comments) == 1: + continue + line = comments[1] + in_comment = False + + # Remove inline block comments + line = re.sub(r"/\*.*\*/", " ", line) + + # Does this line start a comment? + comments = line.split("/*", 1) + if len(comments) > 1: + line = comments[0] + in_comment = True + + # Deal with line comments + line = line.split("--", 1)[0] + line = line.split("//", 1)[0] + + # Find *all* semicolons. We need to treat first and last entry + # specially. 
+ statements = line.split(";") + + # We must prepend statement_buffer to the first statement + first_statement = "%s %s" % ( + statement_buffer.strip(), + statements[0].strip() + ) + statements[0] = first_statement + + # Every entry, except the last, is a full statement + for statement in statements[:-1]: + yield statement.strip() + + # The last entry did *not* end in a semicolon, so we store it for the + # next semicolon we find + statement_buffer = statements[-1].strip() + + +def executescript(txn, schema_path): + with open(schema_path, 'r') as f: + for statement in get_statements(f): + txn.execute(statement) + + +def _get_or_create_schema_state(txn, database_engine): + # Bluntly try creating the schema_version tables. + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + executescript(txn, schema_path) + + txn.execute("SELECT version, upgraded FROM schema_version") + row = txn.fetchone() + current_version = int(row[0]) if row else None + upgraded = bool(row[1]) if row else None + + if current_version: + txn.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_schema_deltas WHERE version >= ?" + ), + (current_version,) + ) + applied_deltas = [d for d, in txn.fetchall()] + return current_version, applied_deltas, upgraded + + return None + + +def prepare_sqlite3_database(db_conn): + """This function should be called before `prepare_database` on sqlite3 + databases. + + Since we changed the way we store the current schema version and handle + updates to schemas, we need a way to upgrade from the old method to the + new. This only affects sqlite databases since they were the only ones + supported at the time. + """ + with db_conn: + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + create_schema = read_schema(schema_path) + db_conn.executescript(create_schema) + + c = db_conn.execute("SELECT * FROM schema_version") + rows = c.fetchall() + c.close() + + if not rows: + c = db_conn.execute("PRAGMA user_version") + row = c.fetchone() + c.close() + + if row and row[0]: + db_conn.execute( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + (row[0], False) + ) diff --git a/tests/utils.py b/tests/utils.py index dd19a16fc7..6eb575bd09 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,7 +16,7 @@ from synapse.http.server import HttpServer from synapse.api.errors import cs_error, CodeMessageException, StoreError from synapse.api.constants import EventTypes -from synapse.storage import prepare_database +from synapse.storage.schema_prepare import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer -- cgit 1.4.1 From 892e70ec8404865b4b1e6cf4eef7a3ada7114149 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 28 Oct 2015 16:06:57 +0000 Subject: Add APIs for adding and removing tags from rooms --- synapse/rest/client/v2_alpha/__init__.py | 2 + synapse/rest/client/v2_alpha/tags.py | 89 +++++++++++++++ synapse/storage/__init__.py | 2 + synapse/storage/schema/delta/25/tags.sql | 37 ++++++ synapse/storage/tags.py | 190 +++++++++++++++++++++++++++++++ 5 files changed, 320 insertions(+) create mode 100644 synapse/rest/client/v2_alpha/tags.py create mode 100644 synapse/storage/schema/delta/25/tags.sql create mode 100644 synapse/storage/tags.py (limited to 'synapse/storage/__init__.py') diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index 5831ff0e62..a108132346 100644 --- 
a/synapse/rest/client/v2_alpha/__init__.py
+++ b/synapse/rest/client/v2_alpha/__init__.py
@@ -22,6 +22,7 @@ from . import (
     receipts,
     keys,
     tokenrefresh,
+    tags,
 )
 
 from synapse.http.server import JsonResource
@@ -44,3 +45,4 @@ class ClientV2AlphaRestResource(JsonResource):
         receipts.register_servlets(hs, client_resource)
         keys.register_servlets(hs, client_resource)
         tokenrefresh.register_servlets(hs, client_resource)
+        tags.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
new file mode 100644
index 0000000000..c4a670c5a7
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import client_v2_pattern
+
+from synapse.api.errors import AuthError
+from synapse.http.servlet import RestServlet
+
+from twisted.internet import defer
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TagListServlet(RestServlet):
+    """
+    GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
+    """
+    PATTERN = client_v2_pattern(
+        "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
+    )
+
+    def __init__(self, hs):
+        super(TagListServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id, room_id):
+        auth_user, _ = yield self.auth.get_user_by_req(request)
+        if user_id != auth_user.to_string():
+            raise AuthError(403, "Cannot get tags for other users.")
+
+        tags = yield self.store.get_tags_for_room(user_id, room_id)
+
+        defer.returnValue((200, {"tags": tags}))
+
+
+class TagServlet(RestServlet):
+    """
+    PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
+    DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
+    """
+    PATTERN = client_v2_pattern(
+        "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
+    )
+
+    def __init__(self, hs):
+        super(TagServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, user_id, room_id, tag):
+        auth_user, _ = yield self.auth.get_user_by_req(request)
+        if user_id != auth_user.to_string():
+            raise AuthError(403, "Cannot add tags for other users.")
+
+        yield self.store.add_tag_to_room(user_id, room_id, tag)
+
+        # TODO: poke the notifier.
+        defer.returnValue((200, {}))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, request, user_id, room_id, tag):
+        auth_user, _ = yield self.auth.get_user_by_req(request)
+        if user_id != auth_user.to_string():
+            raise AuthError(403, "Cannot remove tags for other users.")
+
+        yield self.store.remove_tag_from_room(user_id, room_id, tag)
+
+        # TODO: poke the notifier.
+ defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + TagListServlet(hs).register(http_server) + TagServlet(hs).register(http_server) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a1bd9c4ce9..e7443f2838 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -41,6 +41,7 @@ from .end_to_end_keys import EndToEndKeyStore from .receipts import ReceiptsStore from .search import SearchStore +from .tags import TagsStore import logging @@ -71,6 +72,7 @@ class DataStore(RoomMemberStore, RoomStore, ReceiptsStore, EndToEndKeyStore, SearchStore, + TagsStore, ): def __init__(self, hs): diff --git a/synapse/storage/schema/delta/25/tags.sql b/synapse/storage/schema/delta/25/tags.sql new file mode 100644 index 0000000000..168766dcf3 --- /dev/null +++ b/synapse/storage/schema/delta/25/tags.sql @@ -0,0 +1,37 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS room_tags( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + tag TEXT NOT NULL, -- The name of the tag. + CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) +); + +CREATE TABLE IF NOT EXISTS room_tags_revisions ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + stream_id BIGINT NOT NULL, -- The current version of the room tags. + CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id) +); + +CREATE TABLE IF NOT EXISTS private_user_data_max_stream_id( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0); diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py new file mode 100644 index 0000000000..507e60596f --- /dev/null +++ b/synapse/storage/tags.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached +from twisted.internet import defer +from .util.id_generators import StreamIdGenerator + +import logging + +logger = logging.getLogger(__name__) + + +class TagsStore(SQLBaseStore): + def __init__(self, hs): + super(TagsStore, self).__init__(hs) + + self._private_user_data_id_gen = StreamIdGenerator( + "private_user_data_max_stream_id", "stream_id" + ) + + @cached() + def get_tags_for_user(self, user_id): + """Get all the tags for a user. 
+
+        Args:
+            user_id(str): The user to get the tags for.
+        Returns:
+            A deferred dict mapping from room_id strings to lists of tag
+            strings.
+        """
+
+        deferred = self._simple_select_list(
+            "room_tags", {"user_id": user_id}, ["room_id", "tag"]
+        )
+
+        @deferred.addCallback
+        def tags_by_room(rows):
+            tags_by_room = {}
+            for row in rows:
+                tags_by_room.setdefault(row["room_id"], []).append(row["tag"])
+            return tags_by_room
+
+        return deferred
+
+    @defer.inlineCallbacks
+    def get_updated_tags(self, user_id, stream_id):
+        """Get all the tags for the rooms where the tags have changed since
+        the given version.
+
+        Args:
+            user_id(str): The user to get the tags for.
+            stream_id(int): The earliest update to get for the user.
+        Returns:
+            A deferred dict mapping from room_id strings to lists of tag
+            strings for all the rooms that changed since the stream_id token.
+        """
+        def get_updated_tags_txn(txn):
+            sql = (
+                "SELECT room_id FROM room_tags_revisions"
+                " WHERE user_id = ? AND stream_id > ?"
+            )
+            txn.execute(sql, (user_id, stream_id))
+            room_ids = [row[0] for row in txn.fetchall()]
+            return room_ids
+
+        room_ids = yield self.runInteraction(
+            "get_updated_tags", get_updated_tags_txn
+        )
+
+        results = {}
+        if room_ids:
+            tags_by_room = yield self.get_tags_for_user(user_id)
+            for room_id in room_ids:
+                results[room_id] = tags_by_room.get(room_id, [])
+
+        defer.returnValue(results)
+
+    def get_tags_for_room(self, user_id, room_id):
+        """Get all the tags for the given room.
+
+        Args:
+            user_id(str): The user to get tags for
+            room_id(str): The room to get tags for
+        Returns:
+            A deferred list of string tags.
+        """
+        return self._simple_select_onecol(
+            table="room_tags",
+            keyvalues={"user_id": user_id, "room_id": room_id},
+            retcol="tag",
+            desc="get_tags_for_room",
+        )
+
+    @defer.inlineCallbacks
+    def add_tag_to_room(self, user_id, room_id, tag):
+        """Add a tag to a room for a user.
+
+        Returns:
+            A deferred that completes once the tag has been added.
+        """
+        def add_tag_txn(txn, next_id):
+            sql = (
+                "INSERT INTO room_tags (user_id, room_id, tag)"
+                " VALUES (?, ?, ?)"
+            )
+            try:
+                txn.execute(sql, (user_id, room_id, tag))
+            except self.database_engine.module.IntegrityError:
+                # Return early if the row is already in the table
+                # and we don't need to bump the revision number of the
+                # private_user_data.
+                return
+            self._update_revision_txn(txn, user_id, room_id, next_id)
+
+        with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
+            yield self.runInteraction("add_tag", add_tag_txn, next_id)
+
+        self.get_tags_for_user.invalidate((user_id,))
+
+    @defer.inlineCallbacks
+    def remove_tag_from_room(self, user_id, room_id, tag):
+        """Remove a tag from a room for a user.
+
+        Returns:
+            A deferred that completes once the tag has been removed.
+        """
+        def remove_tag_txn(txn, next_id):
+            sql = (
+                "DELETE FROM room_tags"
+                " WHERE user_id = ? AND room_id = ? AND tag = ?"
+            )
+            txn.execute(sql, (user_id, room_id, tag))
+            self._update_revision_txn(txn, user_id, room_id, next_id)
+
+        with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
+            yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
+
+        self.get_tags_for_user.invalidate((user_id,))
+
+    def _update_revision_txn(self, txn, user_id, room_id, next_id):
+        """Update the latest revision of the tags for the given user and room.
+
+        Args:
+            txn: The database cursor
+            user_id(str): The ID of the user.
+            room_id(str): The ID of the room.
+            next_id(int): The revision to advance to.
+        """
+
+        update_max_id_sql = (
+            "UPDATE private_user_data_max_stream_id"
+            " SET stream_id = ?"
+            " WHERE stream_id < ?"
+        )
+        txn.execute(update_max_id_sql, (next_id, next_id))
+
+        update_sql = (
+            "UPDATE room_tags_revisions"
+            " SET stream_id = ?"
+            " WHERE user_id = ?"
+            " AND room_id = ?"
+        )
+        txn.execute(update_sql, (next_id, user_id, room_id))
+
+        if txn.rowcount == 0:
+            insert_sql = (
+                "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)"
+                " VALUES (?, ?, ?)"
+            )
+            try:
+                txn.execute(insert_sql, (user_id, room_id, next_id))
+            except self.database_engine.module.IntegrityError:
+                # Ignore insertion errors. It doesn't matter if the row wasn't
+                # inserted because if two updates happened concurrently the one
+                # with the higher stream_id will not be reported to a client
+                # unless the previous update has completed. It doesn't matter
+                # which stream_id ends up in the table, as long as it is higher
+                # than the id that the client has.
+                pass
-- cgit 1.4.1
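
For orientation, a rough sketch of how the schema_prepare module introduced above might be driven against a fresh sqlite3 database. The create_engine("sqlite3") call is an assumption about this revision's synapse.storage.engines API, not something shown in these patches, so treat the exact invocation as illustrative:

import sqlite3

from synapse.storage.engines import create_engine
from synapse.storage.schema_prepare import (
    prepare_database, prepare_sqlite3_database,
)

db_conn = sqlite3.connect("homeserver.db")
database_engine = create_engine("sqlite3")  # assumed signature at this revision

# Migrate any pre-schema_version sqlite database first, then bring the
# schema up to SCHEMA_VERSION (24 at this point in the series).
prepare_sqlite3_database(db_conn)
prepare_database(db_conn, database_engine)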
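
The get_statements splitter above deliberately handles only the subset of SQL used by the schema files: line comments, /* ... */ block comments, and multiple semicolon-separated statements per line. A minimal sketch of exercising it; since get_statements only iterates its argument, any iterable of strings works, and the sample schema below is made up:

from synapse.storage.schema_prepare import get_statements

schema_lines = [
    "-- a line comment\n",
    "CREATE TABLE foo(id INTEGER);  /* an inline block comment */\n",
    "INSERT INTO foo\n",
    "    VALUES (1); INSERT INTO foo VALUES (2);\n",
]

for statement in get_statements(schema_lines):
    print(statement)

# Expected output:
#   CREATE TABLE foo(id INTEGER)
#   INSERT INTO foo VALUES (1)
#   INSERT INTO foo VALUES (2)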
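
The tags storage keeps two pieces of bookkeeping: a single-row table holding the global maximum stream_id, and a per-(user, room) revision row bumped by _update_revision_txn. A self-contained sqlite3 sketch of that pattern, with the DDL copied from 25/tags.sql and made-up IDs:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE private_user_data_max_stream_id(
        Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE,
        stream_id BIGINT NOT NULL,
        CHECK (Lock='X')
    );
    INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0);

    CREATE TABLE room_tags_revisions (
        user_id TEXT NOT NULL,
        room_id TEXT NOT NULL,
        stream_id BIGINT NOT NULL,
        CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id)
    );
""")


def update_revision(txn, user_id, room_id, next_id):
    # Never move the global high-water mark backwards.
    txn.execute(
        "UPDATE private_user_data_max_stream_id"
        " SET stream_id = ? WHERE stream_id < ?",
        (next_id, next_id),
    )
    # Bump the per-room revision, inserting the row on first change.
    txn.execute(
        "UPDATE room_tags_revisions SET stream_id = ?"
        " WHERE user_id = ? AND room_id = ?",
        (next_id, user_id, room_id),
    )
    if txn.rowcount == 0:
        txn.execute(
            "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)"
            " VALUES (?, ?, ?)",
            (user_id, room_id, next_id),
        )


cur = conn.cursor()
update_revision(cur, "@user:example.com", "!room:example.com", 1)
update_revision(cur, "@user:example.com", "!room:example.com", 2)

cur.execute("SELECT stream_id FROM private_user_data_max_stream_id")
print(cur.fetchone()[0])  # prints 2

cur.execute(
    "SELECT stream_id FROM room_tags_revisions"
    " WHERE user_id = ? AND room_id = ?",
    ("@user:example.com", "!room:example.com"),
)
print(cur.fetchone()[0])  # prints 2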