Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add support for constraints in databricks_sql_table resource #4205

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 126 additions & 4 deletions catalog/resource_sql_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ type SqlColumnInfo struct {
TypeJson string `json:"type_json,omitempty" tf:"computed"`
}

type ConstraintInfo struct {
Name string `json:"name"`
Type string `json:"type"`
KeyColumns []string `json:"key_columns,omitempty"`
ParentTable string `json:"parent_table,omitempty"`
ParentColumns []string `json:"parent_columns,omitempty"`
CheckFormula string `json:"check_formula,omitempty"`
Rely bool `json:"rely,omitempty" tf:"default:false"`
}

type TypeJson struct {
Metadata map[string]any `json:"metadata,omitempty"`
}
Expand All @@ -51,6 +61,7 @@ type SqlTableInfo struct {
ColumnInfos []SqlColumnInfo `json:"columns,omitempty" tf:"alias:column,computed"`
Partitions []string `json:"partitions,omitempty" tf:"force_new"`
ClusterKeys []string `json:"cluster_keys,omitempty"`
Constraints []ConstraintInfo `json:"constraints,omitempty" tf:"alias:constraint"`
StorageLocation string `json:"storage_location,omitempty" tf:"suppress_diff"`
StorageCredentialName string `json:"storage_credential_name,omitempty" tf:"force_new"`
ViewDefinition string `json:"view_definition,omitempty"`
Expand Down Expand Up @@ -242,6 +253,50 @@ func (ti *SqlTableInfo) serializeColumnInfos() string {
return strings.Join(columnFragments[:], ", ") // id INT NOT NULL, name STRING, age INT
}

func (ti *ConstraintInfo) serializePrimaryKeyConstraint() string {
constraint_clause := fmt.Sprintf("CONSTRAINT %s PRIMARY KEY(%s)", ti.getWrappedConstraintName(), ti.getWrappedKeyColumnNames())
if ti.Rely {
constraint_clause += " RELY"
}
return constraint_clause
}

func (ti *ConstraintInfo) serializeForeignKeyConstraint() string {
constraint_clause := fmt.Sprintf("CONSTRAINT %s FOREIGN KEY(%s) REFERENCES %s", ti.getWrappedConstraintName(), ti.getWrappedKeyColumnNames(), ti.ParentTable)
if len(ti.ParentColumns) > 0 {
constraint_clause += fmt.Sprintf("(%s)", ti.getWrappedParentColumnNames())
}
if ti.Rely {
constraint_clause += " RELY"
}
return constraint_clause
}

func (ti *ConstraintInfo) serializeCheckConstraint() string {
constraint_clause := fmt.Sprintf("CONSTRAINT %s CHECK (%s)", ti.getWrappedConstraintName(), ti.CheckFormula)
if ti.Rely {
constraint_clause += " RELY"
}
return constraint_clause
}

func (ti *SqlTableInfo) serializeConstraintsForTableCreateStatement() (string, error) {
constraintFragments := make([]string, len(ti.Constraints))

for i, constraint := range ti.Constraints {
if constraint.Type == "PRIMARY KEY" {
constraintFragments[i] = constraint.serializePrimaryKeyConstraint()
} else if constraint.Type == "FOREIGN KEY" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we only support PRIMARY KEY and FOREIGN KEY, we should fail client side if some other value is used.

constraintFragments[i] = constraint.serializeForeignKeyConstraint()
} else {
err := fmt.Errorf("constraint of type %s is not supported for CREATE TABLE statement", constraint.Type)
return "", err
}
}

return strings.Join(constraintFragments[:], ", "), nil // CONSTRAINT `pk`` PRIMARY KEY (`id`, `nickname`), CONSTRAINT `fk` FOREIGN KEY (`player_id`) REFERENCES players
}

func (ti *SqlTableInfo) serializeProperties() string {
propsMap := make([]string, 0, len(ti.Properties))
for key, value := range ti.Properties {
Expand Down Expand Up @@ -275,7 +330,7 @@ func (ti *SqlTableInfo) getTableTypeString() string {
return "TABLE"
}

func (ti *SqlTableInfo) buildTableCreateStatement() string {
func (ti *SqlTableInfo) buildTableCreateStatement() (string, error) {
statements := make([]string, 0, 10)

isView := ti.TableType == "VIEW"
Expand All @@ -290,7 +345,15 @@ func (ti *SqlTableInfo) buildTableCreateStatement() string {
statements = append(statements, fmt.Sprintf("CREATE %s%s %s", externalFragment, createType, ti.SQLFullName()))

if len(ti.ColumnInfos) > 0 {
statements = append(statements, fmt.Sprintf(" (%s)", ti.serializeColumnInfos()))
columnInfosClause := ti.serializeColumnInfos()
if len(ti.Constraints) > 0 {
constraintStatement, err := ti.serializeConstraintsForTableCreateStatement()
if err != nil {
return "", err
}
columnInfosClause += fmt.Sprintf(", %s", constraintStatement)
}
statements = append(statements, fmt.Sprintf(" (%s)", columnInfosClause))
}

if !isView {
Expand Down Expand Up @@ -329,7 +392,7 @@ func (ti *SqlTableInfo) buildTableCreateStatement() string {

statements = append(statements, ";")

return strings.Join(statements, "")
return strings.Join(statements, ""), nil
}

// Wrapping the column name with backticks to avoid special character messing things up.
Expand All @@ -342,6 +405,21 @@ func (ti *SqlTableInfo) getWrappedClusterKeys() string {
return "`" + strings.Join(ti.ClusterKeys, "`,`") + "`"
}

// Wrapping the constraint name with backticks to avoid special character messing things up.
func (ci ConstraintInfo) getWrappedConstraintName() string {
return fmt.Sprintf("`%s`", ci.Name)
}

// Wrapping constraint column names with backticks to avoid special character messing things up.
func (ci ConstraintInfo) getWrappedKeyColumnNames() string {
return "`" + strings.Join(ci.KeyColumns, "`,`") + "`"
}

// Wrapping parent column name with backticks to avoid special character messing things up.
func (ci ConstraintInfo) getWrappedParentColumnNames() string {
return "`" + strings.Join(ci.ParentColumns, "`,`") + "`"
}

func (ti *SqlTableInfo) getStatementsForColumnDiffs(oldti *SqlTableInfo, statements []string, typestring string) []string {
if len(ti.ColumnInfos) != len(oldti.ColumnInfos) {
statements = ti.addOrRemoveColumnStatements(oldti, statements, typestring)
Expand Down Expand Up @@ -413,6 +491,45 @@ func (ti *SqlTableInfo) alterExistingColumnStatements(oldti *SqlTableInfo, state
return statements
}

func (ti *SqlTableInfo) addOrRemoveConstraintStatements(oldti *SqlTableInfo, statements []string, typestring string) []string {
nameToOldConstraint := make(map[string]ConstraintInfo)
nameToNewConstraint := make(map[string]ConstraintInfo)
for _, ci := range oldti.Constraints {
nameToOldConstraint[ci.Name] = ci
}
for _, newCi := range ti.Constraints {
nameToNewConstraint[newCi.Name] = newCi
}

removeConstraintStatements := make([]string, 0)

for name, oldCi := range nameToOldConstraint {
if _, exists := nameToNewConstraint[name]; !exists {
// Remove old constraint if old constraint is no longer found in the config.
removeConstraintStatements = append(removeConstraintStatements, oldCi.getWrappedConstraintName())
}
}
for _, removeStatement := range removeConstraintStatements {
statements = append(statements, fmt.Sprintf("ALTER %s %s DROP CONSTRAINT IF EXISTS %s", typestring, ti.SQLFullName(), removeStatement))
}

for _, newCi := range ti.Constraints {
if _, exists := nameToOldConstraint[newCi.Name]; !exists {
var newConstraintStatement string
if newCi.Type == "PRIMARY KEY" {
newConstraintStatement = newCi.serializePrimaryKeyConstraint()
} else if newCi.Type == "FOREIGN KEY" {
newConstraintStatement = newCi.serializeForeignKeyConstraint()
} else if newCi.Type == "CHECK" {
newConstraintStatement = newCi.serializeCheckConstraint()
}
statements = append(statements, fmt.Sprintf("ALTER %s %s ADD %s", typestring, ti.SQLFullName(), newConstraintStatement))
}
}

return statements
}

func (ti *SqlTableInfo) diff(oldti *SqlTableInfo) ([]string, error) {
statements := make([]string, 0)
typestring := ti.getTableTypeString()
Expand Down Expand Up @@ -454,6 +571,7 @@ func (ti *SqlTableInfo) diff(oldti *SqlTableInfo) ([]string, error) {
}

statements = ti.getStatementsForColumnDiffs(oldti, statements, typestring)
statements = ti.addOrRemoveConstraintStatements(oldti, statements, typestring)

return statements, nil
}
Expand All @@ -473,7 +591,11 @@ func (ti *SqlTableInfo) updateTable(oldti *SqlTableInfo) error {
}

func (ti *SqlTableInfo) createTable() error {
return ti.applySql(ti.buildTableCreateStatement())
tableCreateStatement, err := ti.buildTableCreateStatement()
if err != nil {
return err
}
return ti.applySql(tableCreateStatement)
}

func (ti *SqlTableInfo) deleteTable() error {
Expand Down
Loading
Loading