
Add doc cherry-pick (#631)
* Add doc (#623)

* Refine load_paddle_model doc (#626)

* refine doc and close LOG (#632)
haozech authored Dec 9, 2021
1 parent fcde5e8 commit de53fce
Showing 32 changed files with 413 additions and 168 deletions.
59 changes: 40 additions & 19 deletions build.sh
@@ -42,6 +42,17 @@ function gpu_on {
cudnn_config=ON
}

function test_doc {
mkdir -p $build_dir
cd $build_dir
export runtime_include_dir=$workspace/cinn/runtime/cuda

prepare_ci
cmake_
build
make_doc
}

function cudnn_off {
cudnn_config=OFF
}
@@ -94,36 +105,46 @@ function prepare_ci {
pip install pre-commit
pip install clang-format==9.0
pip install wheel
pip install sphinx==3.3.1 sphinx_gallery==0.8.1 recommonmark==0.6.0 exhale scipy breathe==4.24.0 matplotlib
pip install sphinx==3.3.1 sphinx_gallery==0.8.1 recommonmark==0.6.0 exhale scipy breathe==4.24.0 matplotlib sphinx_rtd_theme
pip install paddlepaddle-gpu==2.1.2.post101 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
}

function make_doc {
function prepare_doc_model_file {
proxy_off
cd $workspace/tutorials
if [[ -f "ResNet18.tar.gz" ]]; then
echo "model file for tutorials already downloaded."
elif [[ -f "$build_dir/thirds/ResNet18.tar.gz" ]]; then
rm -rf $workspace/tutorials/ResNet18
ln -s $build_dir/thirds/ResNet18 $workspace/tutorials/ResNet18
local tar_file=$1
if [[ -f "$tar_file.tar.gz" ]]; then
echo "model file $tar_file.tar.gz for tutorials already downloaded."
elif [[ -f "$build_dir/thirds/$tar_file.tar.gz" ]]; then
rm -rf $workspace/tutorials/$tar_file
ln -s $build_dir/thirds/$tar_file $workspace/tutorials/$tar_file
else
wget http://paddle-inference-dist.bj.bcebos.com/CINN/ResNet18.tar.gz
tar -zxvf ResNet18.tar.gz
wget https://paddle-inference-dist.bj.bcebos.com/CINN/$tar_file.tar.gz
tar -zxvf $tar_file.tar.gz
fi
}

function make_doc {
proxy_off
cd $workspace/tutorials
prepare_doc_model_file ResNet50
prepare_doc_model_file MobileNetV2
prepare_doc_model_file EfficientNet
prepare_doc_model_file FaceDet

if [[ $cuda_config == "ON" && ! -d "./is_cuda" ]]; then
mkdir is_cuda
fi

if [[ $cuda_config == "OFF" && -d "./is_cuda" ]]; then
rm -rf ./is_cuda
fi
cd $build_dir
rm -f $workspace/python/cinn/core_api.so
ln -s $build_dir/cinn/pybind/core_api.so $workspace/python/cinn/
cd $workspace/docs
mkdir -p docs/source/cpp
cat $workspace/tutorials/matmul.cc | python${py_version} $workspace/tools/gen_c++_tutorial.py > $workspace/docs/source/matmul.md
cat $workspace/tutorials/load_paddle_model.cc | python${py_version} $workspace/tools/gen_c++_tutorial.py > $workspace/docs/source/load_paddle_model.md
make html
if [[ $cuda_config == "ON" && -d "./is_cuda" ]]; then
rm -rf $workspace/tutorials/is_cuda
fi
}
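Side note: make_doc finishes with Sphinx's make html, so the rendered pages land in the docs build directory (commonly docs/build/html; that path is an assumption based on the usual Sphinx Makefile layout, not something shown in this diff). A quick, illustrative way to browse the result locally:

cd $workspace/docs
python3 -m http.server --directory build/html 8000   # serve the built pages, then open http://localhost:8000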

function cmake_ {
@@ -308,6 +329,10 @@ function main {
run_test
shift
;;
test_doc)
test_doc
shift
;;
ci)
CI
shift
@@ -320,10 +345,6 @@
prepare_model
shift
;;
make_doc)
make_doc
shift
;;
esac
done
}
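For readers who want to try this change, the new test_doc subcommand is dispatched by the main case statement above, just like the other build.sh subcommands. A minimal sketch of an invocation, assuming it is run from the repository root with the default GPU/cuDNN settings (the path below is illustrative):

cd /path/to/CINN            # repository root (illustrative path)
bash ./build.sh test_doc    # runs prepare_ci, cmake_, build, then make_doc, per the test_doc function above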
14 changes: 7 additions & 7 deletions cinn/backends/codegen_cuda_dev_test.cc
100644 → 100755
@@ -89,7 +89,7 @@ TEST(CodeGenCUDA, basic) {

CodeGenCUDA_Dev codegen(target);

auto func = Lower("elementwise_add", stages, {A, B, C});
auto func = Lower("elementwise_mul", stages, {A, B, C});

auto compiled = codegen.Compile(func);

@@ -115,7 +115,7 @@ TEST(CodeGenCUDA, Module_output) {

CodeGenCUDA_Dev codegen(target);

auto func = Lower("elementwise_add", stages, {A, B, C});
auto func = Lower("elementwise_mul", stages, {A, B, C});

Module::Builder builder("module", target);
builder.AddFunction(func);
@@ -149,7 +149,7 @@ TEST(CodeGenCUDA2, test_of_cacheread) {
stages[B_cache]->ComputeAt(stages[C], 1);
CodeGenCUDA_Dev codegen(target);

auto func = Lower("elementwise_add", stages, {A, B, C});
auto func = Lower("elementwise_mul", stages, {A, B, C});

Module::Builder builder("module", target);
builder.AddFunction(func);
@@ -181,7 +181,7 @@ TEST(CodeGenCUDA2, test_of_cacheread) {

dim3 grid(10, 1, 1);
dim3 block(10, 1, 1);
cuda_module.LaunchKernel(0, "elementwise_add", grid, block, args);
cuda_module.LaunchKernel(0, "elementwise_mul", grid, block, args);

CUDA_CALL(cudaMemcpy(host_data3.data(),
reinterpret_cast<void*>(Cd),
@@ -221,7 +221,7 @@ TEST(CodeGenCUDA2, test_of_splitcudakernel) {

CodeGenCUDA_Dev codegen(target);

auto func = lang::LowerVec("elementwise_add", stages, {A, B, C, D}, {}, {}, nullptr, target);
auto func = lang::LowerVec("elementwise_mul_and_add", stages, {A, B, C, D}, {}, {}, nullptr, target);

Module::Builder builder("module", target);
for (auto& i : func) {
@@ -251,15 +251,15 @@ typedef char int8_t;
__global__
void __launch_bounds__(200) elementwise_add(const float* __restrict__ X, const float* __restrict__ Y, float* __restrict__ C)
void __launch_bounds__(200) elementwise_mul_and_add(const float* __restrict__ X, const float* __restrict__ Y, float* __restrict__ C)
{
if (((int)blockIdx.x < 100)) {
if (((int)threadIdx.x < 200)) {
C[((200 * (int)blockIdx.x) + (int)threadIdx.x)] = (X[((200 * (int)blockIdx.x) + (int)threadIdx.x)] * Y[((200 * (int)blockIdx.x) + (int)threadIdx.x)]);
};
};
}__global__
void __launch_bounds__(200) elementwise_add_1(const float* __restrict__ X, const float* __restrict__ Y, const float* __restrict__ C, float* __restrict__ D)
void __launch_bounds__(200) elementwise_mul_and_add_1(const float* __restrict__ X, const float* __restrict__ Y, const float* __restrict__ C, float* __restrict__ D)
{
if (((int)blockIdx.x < 100)) {
if (((int)threadIdx.x < 200)) {
6 changes: 3 additions & 3 deletions cinn/backends/compiler.cc
100644 → 100755
@@ -70,14 +70,14 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code,
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
LOG(INFO) << "[CUDA] host module:\n" << host_module;
VLOG(3) << "[CUDA] host module:\n" << host_module;

{ // compile cuda device
LOG(INFO) << "[CUDA] device module:\n" << device_module;
VLOG(3) << "[CUDA] device module:\n" << device_module;
CodeGenCUDA_Dev codegen(target_);
auto source_code = codegen.Compile(device_module);
if (!code.empty()) source_code = code;
LOG(INFO) << "[CUDA] source code:\n" << source_code;
VLOG(3) << "[CUDA] source code:\n" << source_code;
using runtime::cuda::CUDAModule;

backends::NVRTC_Compiler compiler;
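A note on the LOG(INFO) to VLOG(3) changes in this file and the ones below: glog still compiles the VLOG calls in, but they only print when the verbosity level is raised at run time, so the messages disappear from default output while remaining available for debugging. GLOG_v and GLOG_vmodule are standard glog switches; the binary names below are purely illustrative:

GLOG_v=3 ./build/tests/some_cinn_test                                 # show VLOG(3) and lower everywhere
GLOG_vmodule=compiler=3,nvrtc_util=3 ./build/tests/some_cinn_test     # raise verbosity only for selected source files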
10 changes: 5 additions & 5 deletions cinn/backends/llvm/execution_engine.cc
100644 → 100755
@@ -98,7 +98,7 @@ std::unique_ptr<llvm::MemoryBuffer> NaiveObjectCache::getObject(const llvm::Modu
return nullptr;
}

LOG(INFO) << "Object for " << m->getModuleIdentifier() << " loaded from cache.";
VLOG(3) << "Object for " << m->getModuleIdentifier() << " loaded from cache.";
return llvm::MemoryBuffer::getMemBuffer(it->second->getMemBufferRef());
}

@@ -178,25 +178,25 @@ void ExecutionEngine::Link(const ir::Module &module) {

decltype(auto) es = jit_->getExecutionSession();
if (false) {
LOG(INFO) << "======= dump jit execution session ======";
VLOG(3) << "======= dump jit execution session ======";
std::string buffer;
llvm::raw_string_ostream os(buffer);
es.dump(os);
os.flush();
LOG(INFO) << buffer;
VLOG(3) << buffer;
}
}

bool ExecutionEngine::AddModule(std::unique_ptr<llvm::Module> module, std::unique_ptr<llvm::LLVMContext> context) {
module->setDataLayout(jit_->getDataLayout());
if (false) {
LOG(INFO) << "======= dump jit lib ==========";
VLOG(3) << "======= dump jit lib ==========";
std::string buffer;
llvm::raw_string_ostream os(buffer);
module->print(os, {});
// main_jd_->dump(os);
os.flush();
LOG(INFO) << buffer;
VLOG(3) << buffer;
}
llvm::orc::ThreadSafeContext tsc(std::move(context));
llvm::orc::ThreadSafeModule tsm(std::move(module), std::move(tsc));
6 changes: 3 additions & 3 deletions cinn/backends/llvm/simple_jit.cc
100644 → 100755
@@ -71,8 +71,8 @@ void SimpleJIT::AddModule(std::unique_ptr<llvm::Module> module, bool optimize) {
module_pass_manager.run(*module, module_analysis_manager);
}

LOG(INFO) << "jit target: " << jit_->getDataLayout().getStringRepresentation();
LOG(INFO) << "module target: " << module->getDataLayout().getStringRepresentation();
VLOG(3) << "jit target: " << jit_->getDataLayout().getStringRepresentation();
VLOG(3) << "module target: " << module->getDataLayout().getStringRepresentation();

llvm::orc::ThreadSafeModule tsm(std::move(module), context_);
llvm::cantFail(jit_->addIRModule(std::move(tsm)));
@@ -82,7 +82,7 @@ void SimpleJIT::AddModule(std::unique_ptr<llvm::Module> module, bool optimize) {
llvm::raw_string_ostream os(buffer);
jit_->getExecutionSession().dump(os);
os.flush();
LOG(INFO) << "compiled jit:\n" << buffer;
VLOG(3) << "compiled jit:\n" << buffer;
}
}

Empty file modified cinn/backends/llvm/simple_jit.h
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion cinn/backends/nvrtc_util.cc
100644 → 100755
@@ -91,7 +91,7 @@ std::string NVRTC_Compiler::CompilePTX(const std::string& code, bool include_hea
for (const auto& option : compile_options) {
param_cstrings.push_back(option.c_str());
}
LOG(INFO) << "compile options: " << utils::Join(compile_options, " ");
VLOG(3) << "compile options: " << utils::Join(compile_options, " ");
NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());

4 changes: 2 additions & 2 deletions cinn/common/cas.cc
@@ -2005,8 +2005,8 @@ Expr CasSimplifyMutator::FurtherSimplifyFracWithInterval(
auto it = var_intervals.find(bv->name);
auto ai_abs = std::abs(ai->value);
if (it != var_intervals.end()) {
LOG(INFO) << "found " << bv->name << " " << it->second << " "
<< " ai " << ai_abs;
VLOG(3) << "found " << bv->name << " " << it->second << " "
<< " ai " << ai_abs;
}
if (it != var_intervals.end() && std::abs(it->second.r) > ai_abs && std::abs(it->second.l) > ai_abs) {
return make_const(a.type(), 0);
2 changes: 1 addition & 1 deletion cinn/common/ir_util.cc
@@ -125,7 +125,7 @@ Expr RampRelatedMul(Expr a, Expr b) {
CHECK_EQ(a_broadcast->lanes, b_broadcast->lanes);
return ir::Broadcast::Make(a_broadcast->value * b_broadcast->value, a_broadcast->lanes);
} else {
LOG(INFO) << "a,b: " << a << " " << b;
VLOG(3) << "a,b: " << a << " " << b;
CINN_NOT_IMPLEMENTED
}
}
2 changes: 1 addition & 1 deletion cinn/frontend/computation.cc
100644 → 100755
@@ -127,7 +127,7 @@ std::shared_ptr<CinnComputation> CinnComputation::CompilePaddleModel(
}
program->SetInputs({input_vars});
program->Validate();
LOG(INFO) << "program:\n" << *program;
VLOG(3) << "program:\n" << *program;

for (auto &name : fetch_names) {
output_vars.push_back(varmap.at(name));
2 changes: 1 addition & 1 deletion cinn/frontend/interpreter.cc
100644 → 100755
@@ -103,7 +103,7 @@ void Interpreter::Impl::Build(const std::vector<std::string>& input_names,
program_->SetInputs({input_vars});
program_->Validate();

LOG(INFO) << "Program:\n" << *program_;
VLOG(3) << "Program:\n" << *program_;

auto graph = std::make_shared<hlir::framework::Graph>(*program_, target);
graph->attrs["model_name"] = std::make_shared<absl::any>(model_name);
6 changes: 3 additions & 3 deletions cinn/frontend/paddle/model_parser.cc
100644 → 100755
@@ -222,9 +222,9 @@ void LoadModelPb(const std::string &model_dir,
CHECK(cpp_prog);
CHECK(scope);
cpp_prog->ClearBlocks();
LOG(INFO) << "model_dir is: " << model_dir;
LOG(INFO) << "model_file is: " << model_file;
LOG(INFO) << "param_file is: " << param_file;
VLOG(3) << "model_dir is: " << model_dir;
VLOG(3) << "model_file is: " << model_file;
VLOG(3) << "param_file is: " << param_file;
// Load model
VLOG(4) << "Start load model program...";
std::string prog_path = model_dir + "/__model__";
4 changes: 2 additions & 2 deletions cinn/hlir/pe/nn.cc
100644 → 100755
@@ -253,11 +253,11 @@ std::vector<ir::Tensor> Conv2d_NCHW(const ir::Tensor &input,
std::to_string(output_shape_int[1]) + " " + std::to_string(output_shape_int[2]) + " " +
std::to_string(output_shape_int[3]);
if (res.count(key) > 0) {
LOG(INFO) << "Find saved winograd_conv2d schedule param! key is: " << key;
VLOG(3) << "Find saved winograd_conv2d schedule param! key is: " << key;
return Conv2d_winograd_NCHW(
input, weights, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, output_name);
}
LOG(INFO) << "Didn't find saved winograd_conv2d schedule param! key is: " << key;
VLOG(3) << "Didn't find saved winograd_conv2d schedule param! key is: " << key;
}
ir::Tensor input_pad;
if (pad_h == 0 && pad_w == 0) {
4 changes: 2 additions & 2 deletions cinn/optim/replace_var_with_expr.cc
@@ -227,7 +227,7 @@ struct ReplaceVarIndexOfCacheMutator : public ir::IRMutator<> {
if (tensor_shape[index].is_constant() && tensor_shape[index].get_constant() <= 0) {
tensor_shape[index] = Expr(1);
} else if (!tensor_shape[index].is_constant()) {
LOG(INFO) << "Index is not constant: " << tensor_shape[index] << " and it will be replaced to 1";
VLOG(3) << "Index is not constant: " << tensor_shape[index] << " and it will be replaced to 1";
tensor_shape[index] = Expr(1);
}
(*global_tensor_map_).at(tensor_name)->shape = tensor_shape;
@@ -239,7 +239,7 @@
VLOG(3) << i;
}
} else {
LOG(INFO) << "extent not defined";
VLOG(3) << "extent not defined";
}
}

2 changes: 1 addition & 1 deletion cinn/poly/compute_at_transform.cc
100644 → 100755
@@ -128,7 +128,7 @@ void ComputeAtTransform::DisplayC(isl_map* pschedule, isl_map* cschedule) {
auto* build = isl_ast_build_from_context(context.release());
auto* node = isl_ast_build_node_from_schedule_map(build, intersect_schedule.release());

LOG(INFO) << "code:\n\n" << isl_ast_node_to_C_str(node);
VLOG(3) << "code:\n\n" << isl_ast_node_to_C_str(node);

isl_ast_node_free(node);
}
6 changes: 3 additions & 3 deletions cinn/poly/stage.cc
@@ -970,12 +970,12 @@ void Stage::Vectorize(int level, int factor) {
CHECK_LT(level, n_out_dims());
CHECK_GT(factor, 0);
if (factor == 1) {
LOG(INFO) << "Vectorize-factor 1 has no sense, skip it";
VLOG(3) << "Vectorize-factor 1 has no sense, skip it";
return;
}
auto transformed_domain = this->transformed_domain();
if (isl_is_removed_axis(transformed_domain.get(), level)) {
LOG(INFO) << "Vectorizing for-1 has no sense, skip it";
VLOG(3) << "Vectorizing for-1 has no sense, skip it";
return;
}
int removed_axes_counts = isl_get_precending_removed_axes_counts(transformed_domain.get(), level);
@@ -1008,7 +1008,7 @@ void Stage::Parallel(int level) {
auto transformed_domain = this->transformed_domain();
VLOG(3) << "transformed_domain" << transformed_domain;
if (isl_is_removed_axis(transformed_domain.get(), level)) {
LOG(INFO) << "Paralleling for-1 has no sense, skip it";
VLOG(3) << "Paralleling for-1 has no sense, skip it";
return;
}
int removed_axes_counts = isl_get_precending_removed_axes_counts(transformed_domain.get(), level);
4 changes: 2 additions & 2 deletions cinn/pybind/frontend.cc
100644 → 100755
@@ -226,7 +226,7 @@ void BindFrontend(pybind11::module *m) {
CINN_NOT_IMPLEMENTED
}
}
LOG(INFO) << info;
VLOG(3) << info;
program->ExecuteTest(repeat_);
auto out = scope->GetTensor(tensor_out->id);
return out;
@@ -268,7 +268,7 @@ void BindFrontend(pybind11::module *m) {
CINN_NOT_IMPLEMENTED
}
}
LOG(INFO) << info;
VLOG(3) << info;
program->ExecuteTest(repeat_);
auto out = scope->GetTensor(tensor_out->id);
return out;