Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update: create DB SQL for target Provider DB Service #50

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 71 additions & 10 deletions pkg/rdbms/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,30 @@ import (

// mysqlDBMS struct
type MysqlDBMS struct {
provider models.Provider
db *sql.DB
ctx context.Context
provider models.Provider
db *sql.DB
tartgetProvider models.Provider
ctx context.Context
}

type MysqlDBOption func(*MysqlDBMS)

func (d *MysqlDBMS) GetProvdier() models.Provider {
return d.provider
}

func (d *MysqlDBMS) SetProvdier(provider models.Provider) {
d.provider = provider
}

func (d *MysqlDBMS) GetTargetProvdier() models.Provider {
return d.tartgetProvider
}

func (d *MysqlDBMS) SetTargetProvdier(provider models.Provider) {
d.tartgetProvider = provider
}

func New(provider models.Provider, sqlDB *sql.DB, opts ...MysqlDBOption) *MysqlDBMS {
dms := &MysqlDBMS{
provider: provider,
Expand Down Expand Up @@ -143,6 +160,12 @@ func (d *MysqlDBMS) ShowCreateDBSql(dbName string, dbCreateSql *string) error {
*dbCreateSql = addCollateIfMissing(*dbCreateSql)
*dbCreateSql = EnsureCharsetAndCollate(*dbCreateSql, extractCharacterSet(*dbCreateSql), extractCollation(*dbCreateSql))

// If the target provider is NCP, modify the SQL to use NCP's specific procedure
if d.tartgetProvider == models.NCP {
dbName, charSet, collate := extractDatabaseInfo(*dbCreateSql)
*dbCreateSql = fmt.Sprintf("CALL sys.ncp_create_db('%s', '%s', '%s');", dbName, charSet, collate)
}

return nil
}

Expand All @@ -154,6 +177,8 @@ func (d *MysqlDBMS) ShowCreateTableSql(dbName, tableName string, tableCreateSql
if err := d.db.QueryRow(fmt.Sprintf("SHOW CREATE TABLE %s;", tableName)).Scan(&tableName, tableCreateSql); err != nil {
return err
}
*tableCreateSql = removeSequenceOption(*tableCreateSql)
*tableCreateSql = adjustColumnsToTimestamp(*tableCreateSql)
*tableCreateSql = ReplaceCharsetAndCollate(*tableCreateSql)
return nil
}
Expand Down Expand Up @@ -182,10 +207,10 @@ func (d *MysqlDBMS) GetInsert(dbName, tableName string, insertSql *[]string) err
}
defer selRows.Close()

data := []map[string]string{}
data := []map[string]sql.NullString{}

for selRows.Next() {
values := make([]string, len(columns))
values := make([]sql.NullString, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
Expand All @@ -196,22 +221,30 @@ func (d *MysqlDBMS) GetInsert(dbName, tableName string, insertSql *[]string) err
return err
}

entry := make(map[string]string)
entry := make(map[string]sql.NullString)
for i, column := range columns {
val := values[i]
entry[column] = val
entry[column] = values[i]
}

data = append(data, entry)
}

for _, entry := range data {
values := []string{}
escapedColumns := []string{}
for _, column := range columns {
values = append(values, fmt.Sprintf("'%v'", entry[column]))
escapedColumn := escapeColumnName(column)
escapedColumns = append(escapedColumns, escapedColumn)
val := entry[column]
if val.Valid {
escapedValue := ReplaceEscapeString(val.String)
values = append(values, fmt.Sprintf("'%v'", escapedValue))
} else {
values = append(values, "NULL")
}
}

insertStatement := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);", tableName, strings.Join(columns, ", "), strings.Join(values, ", "))
insertStatement := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);", tableName, strings.Join(escapedColumns, ", "), strings.Join(values, ", "))
*insertSql = append(*insertSql, insertStatement)
}

Expand Down Expand Up @@ -247,6 +280,24 @@ func ReplaceCharsetAndCollate(sql string) string {
return sql
}

func ReplaceEscapeString(input string) string {
return strings.ReplaceAll(input, "'", "''")
}

func adjustColumnsToTimestamp(sql string) string {
// Use a regular expression to find all columns that use DEFAULT current_timestamp()
re := regexp.MustCompile("`[^`]+`\\s+[^,]+DEFAULT\\s+current_timestamp\\(\\)")

// Replace these columns with TIMESTAMP DEFAULT current_timestamp()
modifiedSQL := re.ReplaceAllStringFunc(sql, func(match string) string {
// Retain the column name and change the rest of the definition to TIMESTAMP
columnName := strings.Split(match, " ")[0] // The first element is the column name
return fmt.Sprintf("%s TIMESTAMP DEFAULT current_timestamp()", columnName)
})

return modifiedSQL
}

// Extract database information
func extractDatabaseInfo(sql string) (string, string, string) {
dbName := extractDatabaseName(sql)
Expand Down Expand Up @@ -284,3 +335,13 @@ func extractCollation(sql string) string {
}
return ""
}

// remove Sequence
func removeSequenceOption(sql string) string {
return strings.Replace(sql, " SEQUENCE=1", "", -1)
}

// escape Reserve Word
func escapeColumnName(columnName string) string {
return fmt.Sprintf("`%s`", columnName)
}
6 changes: 6 additions & 0 deletions service/rdbc/rdbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"strings"

"github.com/cloud-barista/mc-data-manager/models"
"github.com/sirupsen/logrus"
)

Expand All @@ -33,6 +34,10 @@ const (
//
// Configure the interface to make it easier for other DBs to apply in the future
type RDBMS interface {
GetProvdier() models.Provider
SetProvdier(provider models.Provider)
GetTargetProvdier() models.Provider
SetTargetProvdier(provider models.Provider)
Exec(query string) error
ListDB(dst *[]string) error
DeleteDB(dbName string) error
Expand Down Expand Up @@ -131,6 +136,7 @@ func (rdb *RDBController) Copy(dst *RDBController) error {

for _, db := range dbList {
sql = ""
rdb.client.SetTargetProvdier(dst.client.GetProvdier())
if err := rdb.Get(db, &sql); err != nil {
rdb.logWrite("Error", "Get error", err)
return err
Expand Down
Loading