diff --git a/SecretsManagerRDSMySQLRotationMultiUser/lambda_function.py b/SecretsManagerRDSMySQLRotationMultiUser/lambda_function.py index a51dbabe..6acd9795 100644 --- a/SecretsManagerRDSMySQLRotationMultiUser/lambda_function.py +++ b/SecretsManagerRDSMySQLRotationMultiUser/lambda_function.py @@ -180,7 +180,15 @@ def set_secret(service_client, arn, token): if not conn: logger.error("setSecret: Unable to log into database using current credentials for secret %s" % arn) raise ValueError("Unable to log into database using current credentials for secret %s" % arn) - conn.close() + # get hostname of existing user + try: + with conn.cursor() as cur: + cur.execute("SELECT CURRENT_USER()") + current_user_fullname = cur.fetchone()[0] + user_hostname = current_user_fullname.split("@")[1] + logger.info("User hostname detected: [%s]", user_hostname) + finally: + conn.close() # Use the master arn from the current secret to fetch master secret contents master_arn = current_dict['masterarn'] @@ -200,17 +208,35 @@ def set_secret(service_client, arn, token): # Now set the password to the pending password try: with conn.cursor() as cur: - cur.execute("SELECT User FROM mysql.user WHERE User = %s", pending_dict['username']) + cur.execute( + query="SELECT User FROM mysql.user WHERE User = %s AND Host = %s", + args=(pending_dict["username"], user_hostname), + ) # Create the user if it does not exist if cur.rowcount == 0: - cur.execute("CREATE USER %s IDENTIFIED BY %s", (pending_dict['username'], pending_dict['password'])) + cur.execute( + query="CREATE USER %s@%s IDENTIFIED BY %s", + args=( + pending_dict["username"], + user_hostname, + pending_dict["password"], + ), + ) # Copy grants to the new user - cur.execute("SHOW GRANTS FOR %s", current_dict['username']) + cur.execute( + query="SHOW GRANTS FOR %s@%s", + args=(current_dict["username"], user_hostname), + ) for row in cur.fetchall(): - grant = row[0].split(' TO ') - new_grant_escaped = grant[0].replace('%', '%%') # % is a special character in Python format strings. - cur.execute(new_grant_escaped + " TO %s", (pending_dict['username'],)) + grant = row[0].split(" TO ") + new_grant_escaped = grant[0].replace( + "%", "%%" + ) # % is a special character in Python format strings. + cur.execute( + query=new_grant_escaped + " TO %s@%s", + args=(pending_dict["username"], user_hostname), + ) # Get the version of MySQL cur.execute("SELECT VERSION()") @@ -218,23 +244,68 @@ def set_secret(service_client, arn, token): # Copy TLS options to the new user escaped_encryption_statement = get_escaped_encryption_statement(ver) - cur.execute("SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject FROM mysql.user WHERE User = %s", current_dict['username']) + cur.execute( + query="SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject FROM mysql.user WHERE User = %s AND Host = %s", + args=( + current_dict["username"], + user_hostname, + ), + ) tls_options = cur.fetchone() ssl_type = tls_options[0] if not ssl_type: - cur.execute(escaped_encryption_statement + " NONE", pending_dict['username']) - elif ssl_type == "ANY": - cur.execute(escaped_encryption_statement + " SSL", pending_dict['username']) - elif ssl_type == "X509": - cur.execute(escaped_encryption_statement + " X509", pending_dict['username']) + cur.execute( + query=escaped_encryption_statement + " NONE", + args=( + pending_dict["username"], + user_hostname, + ), + ) + elif "ANY" == ssl_type: + cur.execute( + query=escaped_encryption_statement + " SSL", + args=( + pending_dict["username"], + user_hostname, + ), + ) + elif "X509" == ssl_type: + cur.execute( + query=escaped_encryption_statement + " X509", + args=( + pending_dict["username"], + user_hostname, + ), + ) else: - cur.execute(escaped_encryption_statement + " CIPHER %s AND ISSUER %s AND SUBJECT %s", (pending_dict['username'], tls_options[1], tls_options[2], tls_options[3])) + cur.execute( + query=escaped_encryption_statement + + " CIPHER %s AND ISSUER %s AND SUBJECT %s", + args=( + pending_dict["username"], + user_hostname, + tls_options[1], + tls_options[2], + tls_options[3], + ), + ) # Set the password for the user and commit - password_option = get_password_option(ver) - cur.execute("SET PASSWORD FOR %s = " + password_option, (pending_dict['username'], pending_dict['password'])) + password_option = get_password_option(version=ver) + cur.execute( + query="SET PASSWORD FOR %s@%s = " + password_option, + args=( + pending_dict["username"], + user_hostname, + pending_dict["password"], + ), + ) conn.commit() - logger.info("setSecret: Successfully set password for %s in MySQL DB for secret arn %s." % (pending_dict['username'], arn)) + logger.info( + "setSecret: Successfully set password for %s in MySQL DB for secret arn %s.", + pending_dict["username"], + arn, + ) finally: conn.close() @@ -535,9 +606,9 @@ def get_escaped_encryption_statement(version): """ if version.startswith("5.6"): - return "GRANT USAGE ON *.* TO %s@'%%' REQUIRE" + return "GRANT USAGE ON *.* TO %s@%s REQUIRE" else: - return "ALTER USER %s@'%%' REQUIRE" + return "ALTER USER %s@%s REQUIRE" def is_rds_replica_database(replica_dict, master_dict):