Skip to content

Commit

Permalink
Use keras 3 friendly syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
nzw0301 committed Mar 1, 2024
1 parent 1486d10 commit 371c645
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions mlflow/keras_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
import optuna
from optuna.integration.mlflow import MLflowCallback

from keras.backend import clear_session
from keras.layers import Dense
from keras.layers import Input
from keras.models import Sequential
from keras.optimizers import SGD
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.backend import clear_session
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import SGD


TEST_SIZE = 0.25
Expand All @@ -38,12 +39,12 @@ def standardize(data):

def create_model(num_features, trial):
model = Sequential()
model.add(Input(shape=(num_features,)))
model.add(
Dense(
num_features,
activation="relu",
kernel_initializer="normal",
input_shape=(num_features,),
)
),
model.add(Dense(16, activation="relu", kernel_initializer="normal"))
Expand Down

0 comments on commit 371c645

Please sign in to comment.