diff options
author | Richard van der Hoff <richard@matrix.org> | 2016-03-03 19:05:54 +0000 |
---|---|---|
committer | Richard van der Hoff <richard@matrix.org> | 2016-03-03 19:05:54 +0000 |
commit | a85179aff3bf2bc1b132e9918cd8222a61a8bcc2 (patch) | |
tree | aac1ff9e5bdca7ddc0fdc9d2cae5a5dcf13e8733 /synapse | |
parent | Empty commit (diff) | |
parent | Merge pull request #621 from matrix-org/daniel/ratelimiting (diff) | |
download | synapse-a85179aff3bf2bc1b132e9918cd8222a61a8bcc2.tar.xz |
Merge remote-tracking branch 'origin/develop' into rav/SYN-642
Diffstat (limited to 'synapse')
-rwxr-xr-x | synapse/app/homeserver.py | 2 | ||||
-rw-r--r-- | synapse/handlers/_base.py | 15 | ||||
-rw-r--r-- | synapse/handlers/auth.py | 2 | ||||
-rw-r--r-- | synapse/handlers/directory.py | 20 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 4 | ||||
-rw-r--r-- | synapse/handlers/message.py | 8 | ||||
-rw-r--r-- | synapse/handlers/profile.py | 17 | ||||
-rw-r--r-- | synapse/handlers/room.py | 76 | ||||
-rw-r--r-- | synapse/rest/client/v1/directory.py | 6 | ||||
-rw-r--r-- | synapse/rest/client/v1/profile.py | 4 | ||||
-rw-r--r-- | synapse/rest/client/v1/room.py | 8 | ||||
-rw-r--r-- | synapse/storage/appservice.py | 34 | ||||
-rw-r--r-- | synapse/storage/engines/__init__.py | 5 | ||||
-rw-r--r-- | synapse/storage/engines/postgres.py | 5 | ||||
-rw-r--r-- | synapse/storage/engines/sqlite3.py | 5 | ||||
-rw-r--r-- | synapse/storage/prepare_database.py | 13 | ||||
-rw-r--r-- | synapse/storage/schema/delta/30/as_users.py | 59 |
17 files changed, 193 insertions, 90 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index de5ee988f1..021dc1d610 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -386,7 +386,7 @@ def setup(config_options): tls_server_context_factory = context_factory.ServerContextFactory(config) - database_engine = create_engine(config.database_config["name"]) + database_engine = create_engine(config) config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection hs = SynapseHomeServer( diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index a6f890e0b6..c6a74b0e3d 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -160,10 +160,10 @@ class BaseHandler(object): ) defer.returnValue(res.get(user_id, [])) - def ratelimit(self, user_id): + def ratelimit(self, requester): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.send_message( - user_id, time_now, + requester.user.to_string(), time_now, msg_rate_hz=self.hs.config.rc_messages_per_second, burst_count=self.hs.config.rc_message_burst_count, ) @@ -282,11 +282,18 @@ class BaseHandler(object): return False @defer.inlineCallbacks - def handle_new_client_event(self, event, context, ratelimit=True, extra_users=[]): + def handle_new_client_event( + self, + requester, + event, + context, + ratelimit=True, + extra_users=[] + ): # We now need to go and hit out to wherever we need to hit out to. if ratelimit: - self.ratelimit(event.sender) + self.ratelimit(requester) self.auth.check(event, auth_events=context.current_state) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 62e82a2570..7a4afe446d 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -477,4 +477,4 @@ class AuthHandler(BaseHandler): Returns: Whether self.hash(password) == stored_hash (bool). """ - return bcrypt.checkpw(password, stored_hash) + return bcrypt.hashpw(password, stored_hash) == stored_hash diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index e0a778e7ff..88166f0187 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -212,17 +212,21 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def send_room_alias_update_event(self, user_id, room_id): + def send_room_alias_update_event(self, requester, user_id, room_id): aliases = yield self.store.get_aliases_for_room(room_id) msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event({ - "type": EventTypes.Aliases, - "state_key": self.hs.hostname, - "room_id": room_id, - "sender": user_id, - "content": {"aliases": aliases}, - }, ratelimit=False) + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Aliases, + "state_key": self.hs.hostname, + "room_id": room_id, + "sender": user_id, + "content": {"aliases": aliases}, + }, + ratelimit=False + ) @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3655b9e5e2..6e50b0963e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1657,7 +1657,7 @@ class FederationHandler(BaseHandler): self.auth.check(event, context.current_state) yield self._check_signature(event, auth_events=context.current_state) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context, from_client=False) + yield member_handler.send_membership_event(None, event, context) else: destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) yield self.replication_layer.forward_third_party_invite( @@ -1686,7 +1686,7 @@ class FederationHandler(BaseHandler): # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context, from_client=False) + yield member_handler.send_membership_event(None, event, context) @defer.inlineCallbacks def add_display_name_to_third_party_invite(self, event_dict, event, context): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index afa7c9c36c..cace1cb82a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -215,7 +215,7 @@ class MessageHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def send_nonmember_event(self, event, context, ratelimit=True): + def send_nonmember_event(self, requester, event, context, ratelimit=True): """ Persists and notifies local clients and federation of an event. @@ -241,6 +241,7 @@ class MessageHandler(BaseHandler): defer.returnValue(prev_state) yield self.handle_new_client_event( + requester=requester, event=event, context=context, ratelimit=ratelimit, @@ -268,9 +269,9 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def create_and_send_nonmember_event( self, + requester, event_dict, ratelimit=True, - token_id=None, txn_id=None ): """ @@ -280,10 +281,11 @@ class MessageHandler(BaseHandler): """ event, context = yield self.create_event( event_dict, - token_id=token_id, + token_id=requester.access_token_id, txn_id=txn_id ) yield self.send_nonmember_event( + requester, event, context, ratelimit=ratelimit, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index c9ad5944e6..b45eafbb49 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -89,13 +89,13 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["displayname"]) @defer.inlineCallbacks - def set_displayname(self, target_user, auth_user, new_displayname): + def set_displayname(self, target_user, requester, new_displayname): """target_user is the user whose displayname is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") if new_displayname == '': @@ -109,7 +109,7 @@ class ProfileHandler(BaseHandler): "displayname": new_displayname, }) - yield self._update_join_states(target_user) + yield self._update_join_states(requester) @defer.inlineCallbacks def get_avatar_url(self, target_user): @@ -139,13 +139,13 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["avatar_url"]) @defer.inlineCallbacks - def set_avatar_url(self, target_user, auth_user, new_avatar_url): + def set_avatar_url(self, target_user, requester, new_avatar_url): """target_user is the user whose avatar_url is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") yield self.store.set_profile_avatar_url( @@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler): "avatar_url": new_avatar_url, }) - yield self._update_join_states(target_user) + yield self._update_join_states(requester) @defer.inlineCallbacks def collect_presencelike_data(self, user, state): @@ -199,11 +199,12 @@ class ProfileHandler(BaseHandler): defer.returnValue(response) @defer.inlineCallbacks - def _update_join_states(self, user): + def _update_join_states(self, requester): + user = requester.user if not self.hs.is_mine(user): return - self.ratelimit(user.to_string()) + self.ratelimit(requester) joins = yield self.store.get_rooms_for_user( user.to_string(), diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ad7c83f477..6dd7a41f04 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -18,7 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken +from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester from synapse.api.constants import ( EventTypes, Membership, JoinRules, RoomCreationPreset, ) @@ -90,7 +90,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - self.ratelimit(user_id) + self.ratelimit(requester) if "room_alias_name" in config: for wchar in string.whitespace: @@ -185,23 +185,29 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - yield msg_handler.create_and_send_nonmember_event({ - "type": EventTypes.Name, - "room_id": room_id, - "sender": user_id, - "state_key": "", - "content": {"name": name}, - }, ratelimit=False) + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Name, + "room_id": room_id, + "sender": user_id, + "state_key": "", + "content": {"name": name}, + }, + ratelimit=False) if "topic" in config: topic = config["topic"] - yield msg_handler.create_and_send_nonmember_event({ - "type": EventTypes.Topic, - "room_id": room_id, - "sender": user_id, - "state_key": "", - "content": {"topic": topic}, - }, ratelimit=False) + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Topic, + "room_id": room_id, + "sender": user_id, + "state_key": "", + "content": {"topic": topic}, + }, + ratelimit=False) for invitee in invite_list: room_member_handler.update_membership( @@ -231,7 +237,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() yield directory_handler.send_room_alias_update_event( - user_id, room_id + requester, user_id, room_id ) defer.returnValue(result) @@ -263,7 +269,11 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def send(etype, content, **kwargs): event = create(etype, content, **kwargs) - yield msg_handler.create_and_send_nonmember_event(event, ratelimit=False) + yield msg_handler.create_and_send_nonmember_event( + creator, + event, + ratelimit=False + ) config = RoomCreationHandler.PRESETS_DICT[preset_config] @@ -454,12 +464,11 @@ class RoomMemberHandler(BaseHandler): member_handler = self.hs.get_handlers().room_member_handler yield member_handler.send_membership_event( + requester, event, context, - is_guest=requester.is_guest, ratelimit=ratelimit, remote_room_hosts=remote_room_hosts, - from_client=True, ) if action == "forget": @@ -468,17 +477,19 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def send_membership_event( self, + requester, event, context, - is_guest=False, remote_room_hosts=None, ratelimit=True, - from_client=True, ): """ Change the membership status of a user in a room. Args: + requester (Requester): The local user who requested the membership + event. If None, certain checks, like whether this homeserver can + act as the sender, will be skipped. event (SynapseEvent): The membership event. context: The context of the event. is_guest (bool): Whether the sender is a guest. @@ -486,10 +497,6 @@ class RoomMemberHandler(BaseHandler): the room, and could be danced with in order to join this homeserver for the first time. ratelimit (bool): Whether to rate limit this request. - from_client (bool): Whether this request is the result of a local - client request (rather than over federation). If so, we will - perform extra checks, like that this homeserver can act as this - client. Raises: SynapseError if there was a problem changing the membership. """ @@ -498,9 +505,15 @@ class RoomMemberHandler(BaseHandler): target_user = UserID.from_string(event.state_key) room_id = event.room_id - if from_client: + if requester is not None: sender = UserID.from_string(event.sender) + assert sender == requester.user, ( + "Sender (%s) must be same as requester (%s)" % + (sender, requester.user) + ) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) + else: + requester = Requester(target_user, None, False) message_handler = self.hs.get_handlers().message_handler prev_event = message_handler.deduplicate_state_event(event, context) @@ -510,7 +523,7 @@ class RoomMemberHandler(BaseHandler): action = "send" if event.membership == Membership.JOIN: - if is_guest and not self._can_guest_join(context.current_state): + if requester.is_guest and not self._can_guest_join(context.current_state): # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") @@ -566,6 +579,7 @@ class RoomMemberHandler(BaseHandler): ) else: yield self.handle_new_client_event( + requester, event, context, extra_users=[target_user], @@ -684,12 +698,12 @@ class RoomMemberHandler(BaseHandler): ) else: yield self._make_and_store_3pid_invite( + requester, id_server, medium, address, room_id, inviter, - requester.access_token_id, txn_id=txn_id ) @@ -747,12 +761,12 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def _make_and_store_3pid_invite( self, + requester, id_server, medium, address, room_id, user, - token_id, txn_id ): room_state = yield self.hs.get_state_handler().get_current_state(room_id) @@ -802,6 +816,7 @@ class RoomMemberHandler(BaseHandler): msg_handler = self.hs.get_handlers().message_handler yield msg_handler.create_and_send_nonmember_event( + requester, { "type": EventTypes.ThirdPartyInvite, "content": { @@ -816,7 +831,6 @@ class RoomMemberHandler(BaseHandler): "sender": user.to_string(), "state_key": token, }, - token_id=token_id, txn_id=txn_id, ) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 74ec1e50e0..8c1a2614a0 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -75,7 +75,11 @@ class ClientDirectoryServer(ClientV1RestServlet): yield dir_handler.create_association( user_id, room_alias, room_id, servers ) - yield dir_handler.send_room_alias_update_event(user_id, room_id) + yield dir_handler.send_room_alias_update_event( + requester, + user_id, + room_id + ) except SynapseError as e: raise e except: diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 3c5a212920..953764bd8e 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -51,7 +51,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): defer.returnValue((400, "Unable to parse name")) yield self.handlers.profile_handler.set_displayname( - user, requester.user, new_name) + user, requester, new_name) defer.returnValue((200, {})) @@ -88,7 +88,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): defer.returnValue((400, "Unable to parse name")) yield self.handlers.profile_handler.set_avatar_url( - user, requester.user, new_name) + user, requester, new_name) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f5ed4f7302..cbf3673eff 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -158,12 +158,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet): if event_type == EventTypes.Member: yield self.handlers.room_member_handler.send_membership_event( + requester, event, context, - is_guest=requester.is_guest, ) else: - yield msg_handler.send_nonmember_event(event, context) + yield msg_handler.send_nonmember_event(requester, event, context) defer.returnValue((200, {"event_id": event.event_id})) @@ -183,13 +183,13 @@ class RoomSendEventRestServlet(ClientV1RestServlet): msg_handler = self.handlers.message_handler event = yield msg_handler.create_and_send_nonmember_event( + requester, { "type": event_type, "content": content, "room_id": room_id, "sender": requester.user.to_string(), }, - token_id=requester.access_token_id, txn_id=txn_id, ) @@ -504,6 +504,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): msg_handler = self.handlers.message_handler event = yield msg_handler.create_and_send_nonmember_event( + requester, { "type": EventTypes.Redaction, "content": content, @@ -511,7 +512,6 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): "sender": requester.user.to_string(), "redacts": event_id, }, - token_id=requester.access_token_id, txn_id=txn_id, ) diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 1100c67714..371600eebb 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -34,8 +34,8 @@ class ApplicationServiceStore(SQLBaseStore): def __init__(self, hs): super(ApplicationServiceStore, self).__init__(hs) self.hostname = hs.hostname - self.services_cache = [] - self._populate_appservice_cache( + self.services_cache = ApplicationServiceStore.load_appservices( + hs.hostname, hs.config.app_service_config_files ) @@ -144,21 +144,23 @@ class ApplicationServiceStore(SQLBaseStore): return rooms_for_user_matching_user_id - def _load_appservice(self, as_info): + @classmethod + def _load_appservice(cls, hostname, as_info, config_filename): required_string_fields = [ - # TODO: Add id here when it's stable to release - "url", "as_token", "hs_token", "sender_localpart" + "id", "url", "as_token", "hs_token", "sender_localpart" ] for field in required_string_fields: if not isinstance(as_info.get(field), basestring): - raise KeyError("Required string field: '%s'", field) + raise KeyError("Required string field: '%s' (%s)" % ( + field, config_filename, + )) localpart = as_info["sender_localpart"] if urllib.quote(localpart) != localpart: raise ValueError( "sender_localpart needs characters which are not URL encoded." ) - user = UserID(localpart, self.hostname) + user = UserID(localpart, hostname) user_id = user.to_string() # namespace checks @@ -188,25 +190,30 @@ class ApplicationServiceStore(SQLBaseStore): namespaces=as_info["namespaces"], hs_token=as_info["hs_token"], sender=user_id, - id=as_info["id"] if "id" in as_info else as_info["as_token"], + id=as_info["id"], ) - def _populate_appservice_cache(self, config_files): - """Populates a cache of Application Services from the config files.""" + @classmethod + def load_appservices(cls, hostname, config_files): + """Returns a list of Application Services from the config files.""" if not isinstance(config_files, list): logger.warning( "Expected %s to be a list of AS config files.", config_files ) - return + return [] # Dicts of value -> filename seen_as_tokens = {} seen_ids = {} + appservices = [] + for config_file in config_files: try: with open(config_file, 'r') as f: - appservice = self._load_appservice(yaml.load(f)) + appservice = ApplicationServiceStore._load_appservice( + hostname, yaml.load(f), config_file + ) if appservice.id in seen_ids: raise ConfigError( "Cannot reuse ID across application services: " @@ -226,11 +233,12 @@ class ApplicationServiceStore(SQLBaseStore): ) seen_as_tokens[appservice.token] = config_file logger.info("Loaded application service: %s", appservice) - self.services_cache.append(appservice) + appservices.append(appservice) except Exception as e: logger.error("Failed to load appservice from '%s'", config_file) logger.exception(e) raise + return appservices class ApplicationServiceTransactionStore(SQLBaseStore): diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 4290aea83a..a48230b93f 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -26,12 +26,13 @@ SUPPORTED_MODULE = { } -def create_engine(name): +def create_engine(config): + name = config.database_config["name"] engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: module = importlib.import_module(name) - return engine_class(module) + return engine_class(module, config=config) raise RuntimeError( "Unsupported database engine '%s'" % (name,) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 17b7a9c077..a09685b4df 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -21,9 +21,10 @@ from ._base import IncorrectDatabaseSetup class PostgresEngine(object): single_threaded = False - def __init__(self, database_module): + def __init__(self, database_module, config): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) + self.config = config def check_database(self, txn): txn.execute("SHOW SERVER_ENCODING") @@ -44,7 +45,7 @@ class PostgresEngine(object): ) def prepare_database(self, db_conn): - prepare_database(db_conn, self) + prepare_database(db_conn, self, config=self.config) def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 91fac33b8b..522b905949 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -23,8 +23,9 @@ import struct class Sqlite3Engine(object): single_threaded = True - def __init__(self, database_module): + def __init__(self, database_module, config): self.module = database_module + self.config = config def check_database(self, txn): pass @@ -38,7 +39,7 @@ class Sqlite3Engine(object): def prepare_database(self, db_conn): prepare_sqlite3_database(db_conn) - prepare_database(db_conn, self) + prepare_database(db_conn, self, config=self.config) def is_deadlock(self, error): return False diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 0fd5d497ab..3f29aad1e8 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -50,7 +50,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine): +def prepare_database(db_conn, database_engine, config): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. """ @@ -61,10 +61,10 @@ def prepare_database(db_conn, database_engine): if version_info: user_version, delta_files, upgraded = version_info _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine + cur, user_version, delta_files, upgraded, database_engine, config ) else: - _setup_new_database(cur, database_engine) + _setup_new_database(cur, database_engine, config) # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) @@ -75,7 +75,7 @@ def prepare_database(db_conn, database_engine): raise -def _setup_new_database(cur, database_engine): +def _setup_new_database(cur, database_engine, config): """Sets up the database by finding a base set of "full schemas" and then applying any necessary deltas. @@ -148,11 +148,12 @@ def _setup_new_database(cur, database_engine): applied_delta_files=[], upgraded=False, database_engine=database_engine, + config=config, ) def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): + upgraded, database_engine, config): """Upgrades an existing database. Delta files can either be SQL stored in *.sql files, or python modules @@ -245,7 +246,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module_name, absolute_path, python_file ) logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) + module.run_upgrade(cur, database_engine, config=config) elif ext == ".pyc": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py new file mode 100644 index 0000000000..4cf4dd0917 --- /dev/null +++ b/synapse/storage/schema/delta/30/as_users.py @@ -0,0 +1,59 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from synapse.storage.appservice import ApplicationServiceStore + + +logger = logging.getLogger(__name__) + + +def run_upgrade(cur, database_engine, config, *args, **kwargs): + # NULL indicates user was not registered by an appservice. + cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") + + cur.execute("SELECT name FROM users") + rows = cur.fetchall() + + config_files = [] + try: + config_files = config.app_service_config_files + except AttributeError: + logger.warning("Could not get app_service_config_files from config") + pass + + appservices = ApplicationServiceStore.load_appservices( + config.server_name, config_files + ) + + owned = {} + + for row in rows: + user_id = row[0] + for appservice in appservices: + if appservice.is_exclusive_user(user_id): + if user_id in owned.keys(): + logger.error( + "user_id %s was owned by more than one application" + " service (IDs %s and %s); assigning arbitrarily to %s" % + (user_id, owned[user_id], appservice.id, owned[user_id]) + ) + owned[user_id] = appservice.id + + for user_id, as_id in owned.items(): + cur.execute( + database_engine.convert_param_style( + "UPDATE users SET appservice_id = ? WHERE name = ?" + ), + (as_id, user_id) + ) |