Crocoddyl
Loading...
Searching...
No Matches
conversions.hpp
1
// BSD 3-Clause License
//
// Copyright (C) 2024-2025, Heriot-Watt University
// Copyright note valid unless otherwise stated in individual files.
// All rights reserved.
8
9#ifndef CROCODDYL_UTILS_CONVERSIONS_HPP_
10#define CROCODDYL_UTILS_CONVERSIONS_HPP_
11
#include <memory>
#include <type_traits>
#include <vector>
14
15#ifdef CROCODDYL_WITH_CODEGEN
16#include <cppad/cg/support/cppadcg_eigen.hpp>
17#include <cppad/cppad.hpp>
18#endif
19
20#include "crocoddyl/core/mathbase.hpp"
21
namespace crocoddyl {

/**
 * @brief Select a floating-point scalar type for `Scalar`.
 *
 * `ScalarSelector<Scalar>::type` is `Scalar` itself when `Scalar` is a
 * floating-point type (float, double, long double), and `double` for any
 * other type.
 */
template <typename Scalar>
struct ScalarSelector {
  typedef typename std::conditional<std::is_floating_point<Scalar>::value,
                                    Scalar, double>::type type;
};

/**
 * @brief Cast between floating-point types.
 *
 * Participates in overload resolution only when both `NewScalar` and `Scalar`
 * are floating-point types.
 *
 * @param[in] x  Value to convert.
 * @return `x` converted with `static_cast<NewScalar>`.
 */
// `inline` rather than `static`: a function template defined in a header
// needs no internal linkage, and `static` would give every translation unit
// its own private copy.
template <typename NewScalar, typename Scalar>
inline typename std::enable_if<std::is_floating_point<NewScalar>::value &&
                                   std::is_floating_point<Scalar>::value,
                               NewScalar>::type
scalar_cast(const Scalar& x) {
  return static_cast<NewScalar>(x);
}

/**
 * @brief Cast every element of a vector of `ItemTpl<Scalar>` objects.
 *
 * Each element is converted by calling its `cast<NewScalar>()` member
 * template; element order is preserved.
 *
 * @param[in] in  Source vector.
 * @return A new vector holding one converted element per input element.
 */
template <typename NewScalar, typename Scalar,
          template <typename> class ItemTpl>
std::vector<ItemTpl<NewScalar>> vector_cast(
    const std::vector<ItemTpl<Scalar>>& in) {
  std::vector<ItemTpl<NewScalar>> out;
  out.reserve(in.size());  // single allocation for the whole result
  for (const auto& obj : in) {
    out.push_back(obj.template cast<NewScalar>());
  }
  return out;
}

/**
 * @brief Cast every element of a vector of `std::shared_ptr<ItemTpl<Scalar>>`.
 *
 * Each non-null element is converted by calling its `cast<NewScalar>()`
 * member template; the returned shared pointer is then converted to
 * `std::shared_ptr<ItemTpl<NewScalar>>` with `std::static_pointer_cast`.
 * Null entries are kept as null entries, so the output has the same length
 * and ordering as the input.
 *
 * @param[in] in  Source vector of shared pointers.
 * @return A new vector with one entry per input entry, in the same order.
 */
template <typename NewScalar, typename Scalar,
          template <typename> class ItemTpl>
std::vector<std::shared_ptr<ItemTpl<NewScalar>>> vector_cast(
    const std::vector<std::shared_ptr<ItemTpl<Scalar>>>& in) {
  std::vector<std::shared_ptr<ItemTpl<NewScalar>>> out;
  out.reserve(in.size());  // single allocation for the whole result
  for (const auto& obj : in) {
    if (!obj) {
      // Calling `obj->cast` on a null pointer would be undefined behaviour;
      // keep the slot so indices still line up with the input.
      out.push_back(nullptr);
      continue;
    }
    out.push_back(std::static_pointer_cast<ItemTpl<NewScalar>>(
        obj->template cast<NewScalar>()));
  }
  return out;
}

}  // namespace crocoddyl
65
66#ifdef CROCODDYL_WITH_CODEGEN
67
68// Specialize Eigen's internal cast_impl for your specific types
69namespace Eigen {
70namespace internal {
71
72template <>
73struct cast_impl<CppAD::AD<CppAD::cg::CG<double>>, float> {
74 EIGEN_DEVICE_FUNC static inline float run(
75 const CppAD::AD<CppAD::cg::CG<double>>& x) {
76 // Perform the conversion. This example extracts the value from the AD type.
77 // You might need to adjust this depending on the specific implementation of
78 // CppAD::cg::CG<double>.
79 return static_cast<float>(CppAD::Value(x).getValue());
80 }
81};
82
83template <>
84struct cast_impl<CppAD::AD<CppAD::cg::CG<double>>, double> {
85 EIGEN_DEVICE_FUNC static inline double run(
86 const CppAD::AD<CppAD::cg::CG<double>>& x) {
87 return CppAD::Value(x).getValue();
88 }
89};
90
91template <>
92struct cast_impl<CppAD::AD<CppAD::cg::CG<float>>, float> {
93 EIGEN_DEVICE_FUNC static inline float run(
94 const CppAD::AD<CppAD::cg::CG<float>>& x) {
95 return CppAD::Value(x).getValue();
96 }
97};
98
99template <>
100struct cast_impl<CppAD::AD<CppAD::cg::CG<float>>, double> {
101 EIGEN_DEVICE_FUNC static inline double run(
102 const CppAD::AD<CppAD::cg::CG<float>>& x) {
103 // Perform the conversion. This example extracts the value from the AD type.
104 // You might need to adjust this depending on the specific implementation of
105 // CppAD::cg::CG<float>.
106 return static_cast<float>(CppAD::Value(x).getValue());
107 }
108};
109
110// Convert from CppAD::AD<CppAD::cg::CG<float>> to
111// CppAD::AD<CppAD::cg::CG<double>>
112template <>
113struct cast_impl<CppAD::AD<CppAD::cg::CG<float>>,
114 CppAD::AD<CppAD::cg::CG<double>>> {
115 EIGEN_DEVICE_FUNC static inline CppAD::AD<CppAD::cg::CG<double>> run(
116 const CppAD::AD<CppAD::cg::CG<float>>& x) {
117 return CppAD::AD<CppAD::cg::CG<double>>(
118 CppAD::cg::CG<double>(CppAD::Value(x).getValue()));
119 }
120};
121
122// Convert from CppAD::AD<CppAD::cg::CG<double>> to
123// CppAD::AD<CppAD::cg::CG<float>>
124template <>
125struct cast_impl<CppAD::AD<CppAD::cg::CG<double>>,
126 CppAD::AD<CppAD::cg::CG<float>>> {
127 EIGEN_DEVICE_FUNC static inline CppAD::AD<CppAD::cg::CG<float>> run(
128 const CppAD::AD<CppAD::cg::CG<double>>& x) {
129 return CppAD::AD<CppAD::cg::CG<float>>(
130 CppAD::cg::CG<float>(static_cast<float>(CppAD::Value(x).getValue())));
131 }
132};
133
134} // namespace internal
135} // namespace Eigen
136
137namespace crocoddyl {
138
139// Casting to CppAD types from floating-point types
140template <typename NewScalar, typename Scalar>
141static typename std::enable_if<
142 std::is_floating_point<Scalar>::value &&
143 (std::is_same<NewScalar, CppAD::AD<CppAD::cg::CG<double>>>::value ||
144 std::is_same<NewScalar, CppAD::AD<CppAD::cg::CG<float>>>::value),
145 NewScalar>::type
146scalar_cast(const Scalar& x) {
147 return static_cast<NewScalar>(x);
148}
149
150// Casting to floating-point types from CppAD types
151template <typename NewScalar, typename Scalar>
152static inline typename std::enable_if<std::is_floating_point<Scalar>::value,
153 NewScalar>::type
154scalar_cast(const CppAD::AD<CppAD::cg::CG<Scalar>>& x) {
155 return static_cast<NewScalar>(CppAD::Value(x).getValue());
156}
157
158} // namespace crocoddyl
159
160#endif
161
162#endif // CROCODDYL_UTILS_CONVERSIONS_HPP_