Skip to content

Commit

Permalink
Fix keras debugger with new tf2
Browse files Browse the repository at this point in the history
  • Loading branch information
MarsTechHAN committed Feb 13, 2022
1 parent 741f654 commit 9d9380d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions keras2ncnn/keras_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class KerasDebugger:
'''

input_extractor_template = '\tncnn::Mat $layer_name_rep$_blob;\n'\
'\t$layer_name_rep$_blob.create($input_shape$, 4u);\n'\
'\t$layer_name_rep$_blob.create($input_shape$);\n'\
'\trand_mat($layer_name_rep$_blob);\n'\
'\tex.input("$layer_name$_blob", $layer_name_rep$_blob);\n\n'

Expand Down Expand Up @@ -279,7 +279,7 @@ def run_debug(self):

def decode(self, h5_file, keras_graph, graph_seq):
import numpy as np # pylint: disable=import-outside-toplevel
from tensorflow.python import keras # pylint: disable=import-outside-toplevel
# from tensorflow import keras # pylint: disable=import-outside-toplevel
from tensorflow.python.keras import backend as K # pylint: disable=import-outside-toplevel
from tensorflow.python.keras.models import model_from_json
K.set_learning_phase(0)
Expand Down Expand Up @@ -366,7 +366,7 @@ def is_log_file(x): return '.dat' in x
for func_idx in range(len(output_node_names)):

functor = K.function(inputs[func_idx], output_nodes[func_idx])
layer_outs = functor(input_images + [1, ])
layer_outs = functor(input_images )
keras_layer_dumps_list.append(
dict(zip(output_node_names[func_idx], layer_outs)))

Expand Down Expand Up @@ -476,4 +476,4 @@ def is_log_file(x): return '.dat' in x

print(
'Top-k:\nKeras Top-k: \t%s\nncnn Top-k: \t%s' %
(keras_topk_str, ncnn_topk_str))
(keras_topk_str, ncnn_topk_str))

0 comments on commit 9d9380d

Please sign in to comment.