Skip to content

Commit

Permalink
draft for forward AD at compile time. Not finished yet
Browse files Browse the repository at this point in the history
  • Loading branch information
Konrad1991 committed Dec 14, 2023
1 parent 3ae2226 commit 5043d54
Show file tree
Hide file tree
Showing 4 changed files with 341 additions and 20 deletions.
57 changes: 57 additions & 0 deletions R/DerivHelper.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@

TreeInformation <- R6::R6Class(
public = list(
variableList = c(),
mapName = "",

initialize = function(variableList, mapName) {
self$variableList = variableList
self$mapName = mapName
},

isVariable = function(variable) {
if(c(deparse(variable)) %in% self$variableList) {
return(match(c(deparse(variable)), self$variableList))
} else {
return(-1)
}
},

getAst = function(code) {
if(!is.call(code)) {
return(code)
}

code <- as.list(code)
if(code[[1]] != as.name("VarPointer")) {
for(i in seq_along(1:length(code))) {
Idx = self$isVariable(code[[i]])
if(Idx != -1) {
code[[i]] <- str2lang(paste("VarPointer(", self$mapName, ",", Idx, ");"))
}
}
}


lapply(code, function(x) {
self$getAst(x)
})
},

getCalls = function(code) {
out <- purrr::map_if(code, is.list, self$getCalls)
out <- as.call(out)
return(out)
},

convert = function(code) {
modCode = self$getAst(code)
modCode = self$getCalls(modCode)
return(modCode)
}

)
)
#code <- quote(a*b[1] + b*b*a)
#ti <- TreeInformation$new(c("a", "b"), "Map" )
#ti$convert(code)
266 changes: 247 additions & 19 deletions inst/include/etr_new/Derivs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@
#include "unaryCalculations.hpp"
#include "binaryCalculations.hpp"

#include <chrono>

/*
Assume v.l and v.r are of type Buffer.
Than lDeriv and rDeriv are of type Vec<Buffer>.
Thus, l, r, lDeriv and rDeriv are objects which exists.
Here, it is possible to use pointers to them instead of copying.
if e.g. v.l is of type Operation I will move the object.
Or use dfdr and pass all expression created in R?
--> Does not work as variables are removed which store old derivs
*/

namespace etr {

template <typename L, typename R, binaryFct f, typename Trait,
Expand Down Expand Up @@ -76,8 +89,9 @@ struct QuaternaryOperation {
using TypeTrait = Trait;
using CaseTrait = QuarternaryTrait;
using RetType = typename CTrait::RetType;
const L l;
const R r;

const L& l;
const R& r;
const LD lDeriv;
const RD rDeriv;
MatrixParameter mp;
Expand All @@ -100,7 +114,7 @@ struct QuaternaryOperation {
lDeriv(lDeriv_), rDeriv(rDeriv_), mp(mp_) {
ass(l.size() == lDeriv.size(), "l size and lDeriv size do not match"); ass(r.size() == rDeriv.size(), "r size and rDeriv size do not match");
}

template <typename LType, typename RType, typename LDType, typename RDType, quaternaryFct fOther, typename TraitOther, typename CTraitOther>
QuaternaryOperation(const QuaternaryOperation<LType, RType, LDType, RDType, fOther, TraitOther, CTraitOther> &other) :
l(other.l), r(other.r), lDeriv(other.lDeriv), rDeriv(other.rDeriv), mp(other.mp) {
Expand All @@ -111,9 +125,9 @@ struct QuaternaryOperation {
constexpr bool isDoubleL = std::is_same_v<L, double>;
constexpr bool isDoubleR = std::is_same_v<R, double>;
if constexpr (isDoubleL && isDoubleR) {
return f(l, r, lDeriv, *rDeriv);
return f(l, r, lDeriv, rDeriv);
} else if constexpr (!isDoubleL && isDoubleR) {
return f(l[i % l.size()], r, lDeriv[i % lDeriv.size()], rDeriv);
return f(l[i % l.size()], r, lDeriv[i % lDeriv.size()], rDeriv);
} else if constexpr (isDoubleL && !isDoubleR) {
return f(l, r[i % r.size()], lDeriv, rDeriv[i % rDeriv.size()]);
} else if constexpr (!isDoubleL && !isDoubleR) {
Expand Down Expand Up @@ -149,31 +163,28 @@ template<typename T>
requires (IsVariable<T> && !IsVec<T>)
inline Vec<BaseType> walk(T& v,
std::unordered_map<const etr::Buffer<etr::BaseType>*, etr::Vec<etr::BaseType>>& varDerivs) {
return varDerivs[&v];
try {
return varDerivs.at(&v);
} catch(const std::out_of_range& ex) {
return Vec<BaseType>(v.size());
}
}

template<typename T>
requires IsMultiplication<T>
inline auto walk(T& v,
std::unordered_map<const etr::Buffer<etr::BaseType>*, etr::Vec<etr::BaseType>>& varDerivs) ->
QuaternaryOperation<decltype(v.l), decltype(v.r),
decltype(walk(v.l, varDerivs).d),
decltype(walk(v.r, varDerivs).d),
decltype(walk(v.l, varDerivs)),
decltype(walk(v.r, varDerivs)),
TimesDeriv, TimesDerivTrait, QuarternaryTrait> {
auto lDeriv = walk(v.l, varDerivs);
auto rDeriv = walk(v.r, varDerivs);
/*
Assume v.l and v.r are of type Buffer.
Than lDeriv and rDeriv are of type Vec<Buffer>.
Thus, l, r, lDeriv and rDeriv are objects which exists.
Here, it is possible to use pointers to them instead of copying.
Do I have to copy v.l if it is of type Operation?
*/
MatrixParameter mp; defineMatrix(v.l, v.r, mp);
return QuaternaryOperation<decltype(v.l), decltype(v.r),
decltype(lDeriv.d), decltype(rDeriv.d),
decltype(lDeriv), decltype(rDeriv),
TimesDeriv, TimesDerivTrait, QuarternaryTrait>(
v.l, v.r, lDeriv.d, rDeriv.d, mp);
std::move(v.l), std::move(v.r), std::move(lDeriv), std::move(rDeriv), mp);
}

template<typename T>
Expand All @@ -189,22 +200,239 @@ inline auto walk(T& v,
auto rDeriv = walk(v.r, varDerivs);
MatrixParameter mp; defineMatrix(v.l, v.r, mp);
return BinaryOperationDeriv<decltype(lDeriv), decltype(rDeriv),
PlusDeriv, PlusDerivTrait>(lDeriv, rDeriv, mp);
PlusDeriv, PlusDerivTrait>(std::move(lDeriv), std::move(rDeriv), mp);
}

template<typename T>
requires (IsVec<T> && !IsVariable<T>)
inline void walk(T&& v, Vec<BaseType>& result,
std::unordered_map<const etr::Buffer<etr::BaseType>*, etr::Vec<etr::BaseType>>& varDerivs) {

auto start_time1 = std::chrono::high_resolution_clock::now();
auto res = walk(v.d, varDerivs);
auto end_time1 = std::chrono::high_resolution_clock::now();
auto duration1 = std::chrono::duration_cast<std::chrono::nanoseconds>(end_time1 - start_time1).count();
std::cout << "Time taken: " << duration1 << " nanoseconds" << std::endl;

auto start_time2 = std::chrono::high_resolution_clock::now();
result.resize(res.size());
for(size_t i = 0; i < res.size(); i++) {
std::cout << res[i] << std::endl;
result.d.p[i] = res[i];
}
auto end_time2 = std::chrono::high_resolution_clock::now();
auto duration2 = std::chrono::duration_cast<std::chrono::nanoseconds>(end_time2 - start_time2).count();
std::cout << "Time taken: " << duration2 << " nanoseconds" << std::endl;
printAST(res);
}






template <typename T>
struct ExtractTypeD;
template <typename T, typename R, typename Trait>
struct ExtractTypeD<Vec<T, R, Trait>> {
using type = R;
};
template <typename T>
using ExtractedTypeD = typename ExtractTypeD<T>::type;

template<typename T>
constexpr T getL() { return T(); }

template<typename T>
constexpr T getR() { return T(); }

template<typename L, typename R, typename LDeriv, typename RDeriv, typename Trait, typename OpTrait>
struct QuarternyType {
using typeTraitL = L;
using typeTraitR = R;
using typeTraitLDeriv = LDeriv;
using typeTraitRDeriv = RDeriv;
using TypeTrait = Trait;
using VarTrait = std::tuple<L, R, RDeriv, RDeriv>;
};
template<typename L, typename R, typename LDeriv, typename RDeriv, typename Trait, typename OpTrait>
inline constexpr QuarternyType <L, R, LDeriv, RDeriv, Trait, OpTrait> produceQuarternyType() {
return QuarternyType<L, R, LDeriv, RDeriv, Trait, OpTrait>();
}

template<typename LDeriv, typename RDeriv, typename Trait, typename OpTrait>
struct BinaryType {
using typeTraitLDeriv = LDeriv;
using typeTraitRDeriv = RDeriv;
using TypeTrait = Trait;
using VarTrait = std::tuple<LDeriv, RDeriv>;
};
template<typename LDeriv, typename RDeriv, typename Trait, typename OpTrait>
inline constexpr BinaryType<LDeriv, RDeriv, Trait, OpTrait> produceBinaryType() {
return BinaryType<LDeriv, RDeriv, Trait, OpTrait>();
}

template<typename T, int Idx>
struct VariableType {
using Type = T;
static constexpr int Index = Idx;
};

template <typename T, int Idx>
inline constexpr auto CalcIncrement() {
if constexpr(IsVariable<T>) {
return 1;
} else {
//return 0;
return CalcIncrement<typename T::typeTraitL, Idx>() + CalcIncrement<typename T::typeTraitR, Idx>();
}
}

template<typename T, int Idx>
requires IsVariable<T>
inline constexpr auto walkT() -> VariableType<T, Idx> {
return VariableType<T, Idx>();
}

template <typename T, int Idx>
requires IsMultiplication<T>
inline constexpr auto walkT() {
return produceQuarternyType< decltype( getL<typename T::typeTraitL>() ),
decltype( getR<typename T::typeTraitR>() ),
decltype( walkT<typename T::typeTraitL, Idx + CalcIncrement<typename T::typeTraitL, Idx>()>() ),
decltype( walkT<typename T::typeTraitR, Idx + CalcIncrement<typename T::typeTraitR, Idx>()>() ),
QuarternaryTrait, TimesDerivTrait>();
}

template <typename T, int Idx>
requires IsAddition<T>
inline constexpr auto walkT() {
return produceBinaryType< decltype( walkT<typename T::typeTraitL, Idx + CalcIncrement<typename T::typeTraitL, Idx>()>() ),
decltype( walkT<typename T::typeTraitR, Idx + CalcIncrement<typename T::typeTraitR, Idx>()>() ),
QuarternaryTrait, PlusDerivTrait>();
}

template<typename T>
requires (IsVec<T> && !IsVariable<T>)
inline void walkT() {
using tD = ExtractedTypeD<T>;
printTAST<tD>();
constexpr int Idx = 0;
constexpr auto res = walkT<tD, Idx>();
std::cout << "final result " << std::endl;
printTAST<decltype(res)>();
return;
}








/*
traverse tree and modify variables --> struct {variable, int}
*/
template <typename I, typename f, typename Trait, typename CTrait>
struct UnaryType {
using TypeTrait = Trait;
using VarTrait = std::tuple<I>;
};

template <typename I, typename f, typename Trait, typename CTrait>
inline constexpr UnaryType<I, f, Trait, CTrait> produceUnaryType() {
return UnaryType<I, f, Trait, CTrait>();
}

template <typename T, int Idx>
inline constexpr auto CalcIncrementNumber() {
if constexpr(IsVariable<T>) {
return 1;
} else if constexpr(std::is_same_v<typename T::CaseTrait, UnaryTrait>) {
return CalcIncrementNumber<typename T::typeTraitObj, Idx>();
} else {
return CalcIncrementNumber<typename T::typeTraitL, Idx>() + CalcIncrementNumber<typename T::typeTraitR, Idx>();
}
}

// issue: predeclare also the walkT functions

template<typename T, int Idx>
requires IsVariable<T>
inline constexpr auto walkAddNumbers();

template<typename T, int Idx>
requires IsMultiplication<T>
inline constexpr auto walkAddNumbers();

template<typename T, int Idx>
requires IsAddition<T>
inline constexpr auto walkAddNumbers();

template<typename T, int Idx>
requires IsSinus<T>
inline constexpr auto walkAddNumbers();



template<typename T, int Idx>
requires IsVariable<T>
inline constexpr auto walkAddNumbers() {
return VariableType<T, Idx>();
}

template<typename T, int Idx>
requires IsMultiplication<T>
inline constexpr auto walkAddNumbers() {
return produceBinaryType< decltype( walkAddNumbers<typename T::typeTraitL, Idx + CalcIncrementNumber<typename T::typeTraitL, Idx>()>() ),
decltype( walkAddNumbers<typename T::typeTraitR, Idx + CalcIncrementNumber<typename T::typeTraitR, Idx>() + 1>() ),
BinaryTrait, TimesDerivTrait>();
}

template<typename T, int Idx>
requires IsAddition<T>
inline constexpr auto walkAddNumbers() {
return produceBinaryType< decltype( walkAddNumbers<typename T::typeTraitL, Idx + CalcIncrementNumber<typename T::typeTraitL, Idx>()>() ),
decltype( walkAddNumbers<typename T::typeTraitR, Idx + CalcIncrementNumber<typename T::typeTraitR, Idx>() + 1>() ),
BinaryTrait, PlusDerivTrait>();
}

template<typename T, int Idx>
requires IsSinus<T>
inline constexpr auto walkAddNumbers() {
return produceUnaryType< decltype( walkAddNumbers<typename T::typeTraitObj, Idx + CalcIncrementNumber<typename T::typeTraitObj, Idx>()>() ),
SinusDerivTrait,
UnaryTrait, SinusTrait>();
}


template<typename T>
inline constexpr auto addNumbersToVars() {
using tD = ExtractedTypeD<T>;
printTAST<tD>();
constexpr auto res = walkAddNumbers<tD, 0>();
printTAST<decltype(res)>();
return;
}



/*
Plan:
1. change variables to struct {variable, int} in original tree. Difficult.
Do it in R.:
> code <- quote(a*b[1] + b*b*a)
> ti <- TreeInformation$new(c("a", "b"), "Map" )
> ti$convert(code)
VarPointer(Map, 1) * VarPointer(Map, 2)[1] + VarPointer(Map, 2) * VarPointer(Map, 2) * VarPointer(Map, 1)
or work by reference. --> needs constructors
2. convert tree to derivative tree
3. Eval tree with lookup
*/




} // namespace etr

#endif
Loading

0 comments on commit 5043d54

Please sign in to comment.