Skip to content

Commit

Permalink
style: format the code
Browse files Browse the repository at this point in the history
  • Loading branch information
chenzhuofu committed Jan 16, 2025
1 parent d77866a commit 798f1ec
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
6 changes: 4 additions & 2 deletions lib/kernels/test/src/test_concat_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ TEST_SUITE(FF_TEST_SUITE) {
allocator.allocate_tensor(output_shape);

Kernels::Concat::forward_kernel(managed_stream.raw_stream(),
output_accessor, input_accessors,
output_accessor,
input_accessors,
concat_axis);

std::vector<float> host_output_data =
Expand All @@ -49,7 +50,8 @@ TEST_SUITE(FF_TEST_SUITE) {
});
Kernels::Concat::backward_kernel(managed_stream.raw_stream(),
output_grad_accessor,
input_grad_accessors, concat_axis);
input_grad_accessors,
concat_axis);
}
}
}
5 changes: 3 additions & 2 deletions lib/kernels/test/src/test_dropout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ TEST_SUITE(FF_TEST_SUITE) {
DropoutPerDeviceState state = Kernels::Dropout::init_kernel(
managed_handle.raw_handle(), dropout_rate, seed, shape, allocator);

auto get_zero_count = [](const std::vector<float>& data) {
return std::count_if(data.begin(), data.end(), [](float x) { return x == 0.0f; });
auto get_zero_count = [](std::vector<float> const &data) {
return std::count_if(
data.begin(), data.end(), [](float x) { return x == 0.0f; });
};

SUBCASE("forward_kernel") {
Expand Down
38 changes: 19 additions & 19 deletions lib/kernels/test/src/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
#include "kernels/managed_ff_stream.h"
#include "kernels/managed_per_device_ff_handle.h"
#include <doctest/doctest.h>
#include <vector>
#include <string>
#include <sstream>
#include <random>
#include <sstream>
#include <string>
#include <vector>

using namespace FlexFlow;

Expand Down Expand Up @@ -53,28 +53,28 @@ bool contains_non_zero(std::vector<T> &data) {
}

template <typename T, typename Func>
std::vector<T> repeat(std::size_t n, Func&& func) {
std::vector<T> result;
// result.reserve(n); // Sometimes we don't have default constructor for T
for (std::size_t i = 0; i < n; ++i) {
result.push_back(func());
}
return result;
std::vector<T> repeat(std::size_t n, Func &&func) {
std::vector<T> result;
// result.reserve(n); // Sometimes we don't have default constructor for T
for (std::size_t i = 0; i < n; ++i) {
result.push_back(func());
}
return result;
}

// Specialize doctest's StringMaker for std::vector<float>
template <>
struct doctest::StringMaker<std::vector<float>> {
static doctest::String convert(const std::vector<float>& vec) {
std::ostringstream oss;
for (size_t i = 0; i < vec.size(); ++i) {
oss << vec[i];
if (i != vec.size() - 1) {
oss << ", ";
}
}
return doctest::String(("[" + oss.str() + "]").c_str());
static doctest::String convert(std::vector<float> const &vec) {
std::ostringstream oss;
for (size_t i = 0; i < vec.size(); ++i) {
oss << vec[i];
if (i != vec.size() - 1) {
oss << ", ";
}
}
return doctest::String(("[" + oss.str() + "]").c_str());
}
};

#endif

0 comments on commit 798f1ec

Please sign in to comment.