Skip to content

Commit

Permalink
codegen tensorflow use attr dict (#1094)
Browse files Browse the repository at this point in the history
* codegen tensorflow use attr dict

* update

* fix return
  • Loading branch information
typhoonzero authored Oct 30, 2019
1 parent 5a64bd6 commit 26d63fc
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
27 changes: 26 additions & 1 deletion pkg/sql/codegen/attribute/attribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ const (
String
// IntList indicates the corresponding attribute is a list of integers
IntList
// Unknown type indicates that the attribute type is dynamically determined.
Unknown
)

// Dictionary contains the description of all attributes
Expand Down Expand Up @@ -68,9 +70,32 @@ func (t Type) String() string {
// 2. Customer checker
func (d Dictionary) Validate(attrs map[string]interface{}) error {
for k, v := range attrs {
var desc *Description
desc, ok := d[k]
if !ok {
return fmt.Errorf(errUnsupportedAttribute, k)
// Support attribute definition like "model.*" to match attributes start with "model"
keyParts := strings.Split(k, ".")
if len(keyParts) == 2 {
wildCard := fmt.Sprintf("%s.*", keyParts[0])
descWild, okWildCard := d[wildCard]
if okWildCard {
desc = descWild
} else {
return fmt.Errorf(errUnsupportedAttribute, k)
}
} else {
return fmt.Errorf(errUnsupportedAttribute, k)
}

}
// unknown type of attribute do not need to run validate
if desc.Type == Unknown {
if desc.Checker != nil {
if err := desc.Checker(v); err != nil {
return err
}
}
continue
}
switch v.(type) {
case int, int32, int64:
Expand Down
8 changes: 8 additions & 0 deletions pkg/sql/codegen/attribute/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,11 @@ func IntRangeChecker(lower, upper *int, includeLower, includeUpper bool) func(in

return checker
}

// EmptyChecker returns a checker function that do **not** check the input.
func EmptyChecker() func(interface{}) error {
checker := func(e interface{}) error {
return nil
}
return checker
}
23 changes: 23 additions & 0 deletions pkg/sql/codegen/tensorflow/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,28 @@ import (

pb "sqlflow.org/sqlflow/pkg/server/proto"
"sqlflow.org/sqlflow/pkg/sql/codegen"
"sqlflow.org/sqlflow/pkg/sql/codegen/attribute"
)

func newFloat32(f float32) *float32 {
return &f
}

func newInt(i int) *int {
return &i
}

var attributeDictionary = attribute.Dictionary{
"train.batch_size": {attribute.Int, `[default=1]
The training batch size.
range: [0,Infinity]`, attribute.IntRangeChecker(newInt(0), nil, false, false)},
"train.epoch": {attribute.Int, `[default=1]
Number of epochs the training will run.
range: [1, Infinity]`, attribute.IntRangeChecker(newInt(0), nil, false, false)},
"model.*": {attribute.Unknown, `parameters defined by the model implementation, e.g. https://www.tensorflow.org/api_docs/python/tf/estimator/DNNClassifier#__init__, customized model example: https://github.com/sql-machine-learning/models/blob/develop/sqlflow_models/dnnclassifier.py#L4`,
attribute.EmptyChecker()},
}

func intArrayToJSONString(ia []int) string {
return strings.Join(strings.Split(fmt.Sprint(ia), " "), ",")
}
Expand Down Expand Up @@ -141,6 +161,9 @@ func isKerasModel(estimator string) (bool, string) {

// Train generates a Python program for train a TensorFlow model.
func Train(ir *codegen.TrainIR) (string, error) {
if err := attributeDictionary.Validate(ir.Attributes); err != nil {
return "", err
}
trainParams := make(map[string]interface{})
modelParams := make(map[string]interface{})
for attrKey, attr := range ir.Attributes {
Expand Down

0 comments on commit 26d63fc

Please sign in to comment.