Skip to content

Commit

Permalink
added the possibility to handle templates in R. started wrtinign Deri…
Browse files Browse the repository at this point in the history
…vative class in Cpp
  • Loading branch information
Konrad1991 committed Dec 18, 2023
1 parent 4c863d3 commit 6f6a97c
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 35 deletions.
37 changes: 37 additions & 0 deletions R/printAST.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,38 @@ printClass <- R6::R6Class("printClass",
print(paste(c(times, obj), collapse = " "), collapse = "")
}
}
},

customAsCall = function(code) {
if(as.character(code[[1]]) == "diamond") {
stopifnot(length(code) == 3)
return(paste0(code[[2]], "<", code[[3]], ">"))
} else {
return(as.call(code))
}
},

getSTRING = function(code) {
for(i in seq_along(1:length(code))) {
if(is.list(code[[i]])) {
code[[i]] <- self$getSTRING(code[[i]])
}
}
code <- self$customAsCall(code)
return(code)
},

removeBackslashes = function(str) {
str <- deparse(str)
str <- paste("r{", str, "}", collapse = "")
print(str)
str <- gsub("\"", "", str)
return(substr(str, 3, nchar(str) - 1))
},

Call2STRING = function(code) {
code <- self$getSTRING(code)
return(self$removeBackslashes(code))
}

)
Expand All @@ -43,3 +75,8 @@ printClass <- R6::R6Class("printClass",
#ast <- pC$getAST(quote(a*b + c*d/h -sin(f)) )
#pC$getCALLS(ast)
#pC$printAST(ast)

#ast <- pC$getAST(quote(diamond(a, 12)*b + c*d + diamond(b, 5) * diamond(h, diamond(k, j)) ) )
#code <- pC$Call2STRING(ast)
#code

59 changes: 35 additions & 24 deletions inst/include/etr_new/Derivs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,43 +271,52 @@ inline constexpr BinaryType<LDeriv, RDeriv, Trait, OpTrait> produceBinaryType()
return BinaryType<LDeriv, RDeriv, Trait, OpTrait>();
}

template<typename T, int Idx>

template<typename T>
struct VariableType {
using Type = T;
static constexpr int Index = Idx;
};
template<typename T>
inline constexpr VariableType<T> produceVariableType() {
return VariableType<T>();
}

template <typename T, int Idx>
inline constexpr auto CalcIncrement() {
if constexpr(IsVariable<T>) {
return 1;
} else {
//return 0;
return CalcIncrement<typename T::typeTraitL, Idx>() + CalcIncrement<typename T::typeTraitR, Idx>();
template<int N>
struct Derivative {
std::array<Buffer<BaseType>*, N> variables;
std::array<Buffer<BaseType>*, N> derivatives;
Derivative(std::initializer_list<Buffer<BaseType>*>& variables_) {
size_t i = 0;
for(auto& elem: variables_) {
variables[i] = elem;
derivatives[i] -> resize(elem -> size());
i++;
}
}
}
};

template<typename T, int Idx>

template<typename T>
requires IsVariable<T>
inline constexpr auto walkT() -> VariableType<T, Idx> {
return VariableType<T, Idx>();
inline constexpr auto walkT() -> VariableType<T> {
return VariableType<T>();
}

template <typename T, int Idx>
template <typename T>
requires IsMultiplication<T>
inline constexpr auto walkT() {
return produceQuarternyType< decltype( getL<typename T::typeTraitL>() ),
decltype( getR<typename T::typeTraitR>() ),
decltype( walkT<typename T::typeTraitL, Idx + CalcIncrement<typename T::typeTraitL, Idx>()>() ),
decltype( walkT<typename T::typeTraitR, Idx + CalcIncrement<typename T::typeTraitR, Idx>()>() ),
return produceQuarternyType< decltype( produceVariableType< typename T::typeTraitL >() ),
decltype( produceVariableType< typename T::typeTraitR >() ),
decltype( walkT<typename T::typeTraitL>() ),
decltype( walkT<typename T::typeTraitR>() ),
QuarternaryTrait, TimesDerivTrait>();
}

template <typename T, int Idx>
template <typename T>
requires IsAddition<T>
inline constexpr auto walkT() {
return produceBinaryType< decltype( walkT<typename T::typeTraitL, Idx + CalcIncrement<typename T::typeTraitL, Idx>()>() ),
decltype( walkT<typename T::typeTraitR, Idx + CalcIncrement<typename T::typeTraitR, Idx>()>() ),
return produceBinaryType< decltype( walkT<typename T::typeTraitL>() ),
decltype( walkT<typename T::typeTraitR>() ),
QuarternaryTrait, PlusDerivTrait>();
}

Expand All @@ -316,10 +325,11 @@ requires (IsVec<T> && !IsVariable<T>)
inline void walkT() {
using tD = ExtractedTypeD<T>;
printTAST<tD>();
constexpr int Idx = 0;
constexpr auto res = walkT<tD, Idx>();
constexpr auto res = walkT<tD>();
std::cout << "final result " << std::endl;
printTAST<decltype(res)>();


return;
}

Expand All @@ -333,6 +343,7 @@ inline void walkT() {
/*
traverse tree and modify variables --> struct {variable, int}
*/
/*
template <typename I, typename f, typename Trait, typename CTrait>
struct UnaryType {
using TypeTrait = Trait;
Expand Down Expand Up @@ -414,7 +425,7 @@ inline constexpr auto addNumbersToVars() {
printTAST<decltype(res)>();
return;
}

*/


/*
Expand Down
25 changes: 15 additions & 10 deletions inst/include/etr_new/UtilsTraits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,21 @@ struct UnaryTrait { using RetType = BaseType; };
struct BinaryTrait { using RetType = BaseType; };
struct QuarternaryTrait { using RetType = BaseType; };

struct PlusDerivTrait { using RetType = BaseType; };
struct TimesDerivTrait { using RetType = BaseType; };
typedef double (*quaternaryFct)(double, double, double, double);
inline double TimesDeriv(double l, double r, double lDeriv, double rDeriv) {
return l*rDeriv + r*lDeriv;
}
inline double PlusDeriv(double lDeriv, double rDeriv) {
return lDeriv + rDeriv;
}

struct PlusDerivTrait {
using RetType = BaseType;
};

struct TimesDerivTrait {
using RetType = BaseType;
};
struct SinusDerivTrait { using RetType = BaseType; };

struct PlusTrait { using RetType = BaseType; };
Expand Down Expand Up @@ -406,14 +419,6 @@ inline double Log(double obj) { return log(obj); }
inline double SquareRoot(double obj) { return sqrt(obj); }
inline double MinusUnary(double obj) { return -obj; }

typedef double (*quaternaryFct)(double, double, double, double);
inline double TimesDeriv(double l, double r, double lDeriv, double rDeriv) {
return l*rDeriv + r*lDeriv;
}
inline double PlusDeriv(double lDeriv, double rDeriv) {
return lDeriv + rDeriv;
}

inline double Equal(double a, double b) {
if(fabs(a - b) < 1E-3) {
return 1.0;
Expand Down
1 change: 0 additions & 1 deletion src/Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
SEXP getXPtrBorrowTest();

sexp testall(BorrowPtr a, double type_test) {

sexp size;
sexp temp;
sexp ret;
Expand Down

0 comments on commit 6f6a97c

Please sign in to comment.