33class ActivationModelQuadFlatLogTpl
34 :
public ActivationModelAbstractTpl<_Scalar> {
36 EIGEN_MAKE_ALIGNED_OPERATOR_NEW
39 typedef _Scalar Scalar;
41 typedef ActivationModelAbstractTpl<Scalar> Base;
44 typedef typename MathBase::VectorXs VectorXs;
45 typedef typename MathBase::MatrixXs MatrixXs;
56 explicit ActivationModelQuadFlatLogTpl(
const std::size_t nr,
57 const Scalar alpha = Scalar(1.))
58 : Base(nr), alpha_(alpha) {
59 if (alpha < Scalar(0.)) {
60 throw_pretty(
"Invalid argument: " <<
"alpha should be a positive value");
63 virtual ~ActivationModelQuadFlatLogTpl() =
default;
71 virtual void calc(
const std::shared_ptr<ActivationDataAbstract>& data,
72 const Eigen::Ref<const VectorXs>& r)
override {
73 if (
static_cast<std::size_t
>(r.size()) != nr_) {
75 "Invalid argument: " <<
"r has wrong dimension (it should be " +
76 std::to_string(nr_) +
")");
78 std::shared_ptr<Data> d = std::static_pointer_cast<Data>(data);
79 d->a0 = r.squaredNorm() / alpha_;
80 data->a_value = log(Scalar(1.0) + d->a0);
89 virtual void calcDiff(
const std::shared_ptr<ActivationDataAbstract>& data,
90 const Eigen::Ref<const VectorXs>& r)
override {
91 if (
static_cast<std::size_t
>(r.size()) != nr_) {
93 "Invalid argument: " <<
"r has wrong dimension (it should be " +
94 std::to_string(nr_) +
")");
96 std::shared_ptr<Data> d = std::static_pointer_cast<Data>(data);
98 d->a1 = Scalar(2.0) / (alpha_ + alpha_ * d->a0);
100 data->Arr.diagonal() = -d->a1 * d->a1 * r.array().square();
101 data->Arr.diagonal().array() += d->a1;
109 virtual std::shared_ptr<ActivationDataAbstract> createData()
override {
110 std::shared_ptr<Data> data =
111 std::allocate_shared<Data>(Eigen::aligned_allocator<Data>(),
this);
115 template <
typename NewScalar>
116 ActivationModelQuadFlatLogTpl<NewScalar> cast()
const {
117 typedef ActivationModelQuadFlatLogTpl<NewScalar> ReturnType;
118 ReturnType res(nr_, scalar_cast<NewScalar>(alpha_));
122 Scalar get_alpha()
const {
return alpha_; };
123 void set_alpha(
const Scalar alpha) { alpha_ = alpha; };
130 virtual void print(std::ostream& os)
const override {
131 os <<
"ActivationModelQuadFlatLog {nr=" << nr_ <<
", a=" << alpha_ <<
"}";
148struct ActivationDataQuadFlatLogTpl
149 :
public ActivationDataAbstractTpl<_Scalar> {
150 EIGEN_MAKE_ALIGNED_OPERATOR_NEW
152 typedef _Scalar Scalar;
154 typedef typename MathBase::VectorXs VectorXs;
155 typedef typename MathBase::DiagonalMatrixXs DiagonalMatrixXs;
156 typedef ActivationDataAbstractTpl<Scalar> Base;
158 template <
typename Activation>
159 explicit ActivationDataQuadFlatLogTpl(Activation*
const activation)
160 : Base(activation), a0(Scalar(0)), a1(Scalar(0)) {}
161 virtual ~ActivationDataQuadFlatLogTpl() =
default;