libcarla/include/system/boost/numeric/ublas/tensor/expression_evaluation.hpp
2024-10-18 13:19:59 +08:00

289 lines
8.7 KiB
C++

//
// Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
//
// Distributed under the Boost Software License, Version 1.0. (See
// accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)
//
// The authors gratefully acknowledge the support of
// Fraunhofer IOSB, Ettlingen, Germany
//
#ifndef _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
#define _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
#include <type_traits>
#include <stdexcept>
namespace boost::numeric::ublas {

/** @brief Forward declaration of the extents type returned by tensor::extents(). */
template<typename SizeType>
class basic_extents;

/** @brief Forward declaration of the tensor class template. */
template<typename ElementType, typename StorageFormat, typename StorageType>
class tensor;

} // namespace boost::numeric::ublas
namespace boost::numeric::ublas::detail {

// Expression-template node types; only forward-declared here.
template<class TensorT, class DerivedT>
struct tensor_expression;

template<class TensorT, class LeftT, class RightT, class BinaryOpT>
struct binary_tensor_expression;

template<class TensorT, class OperandT, class UnaryOpT>
struct unary_tensor_expression;

/** @brief Compile-time check whether type E is, or contains, the tensor type T
 *
 * The primary template yields false. The specializations below yield true when
 * E is T itself, and otherwise recurse into the operands of tensor_expression,
 * binary_tensor_expression and unary_tensor_expression nodes whose first
 * template argument is T.
 */
template<class T, class E>
struct has_tensor_types
{
  static constexpr bool value = false;
};

/** @brief Exact match: E is the tensor type itself. */
template<class T>
struct has_tensor_types<T, T>
{
  static constexpr bool value = true;
};

/** @brief Generic expression node: true if the derived type is T or contains T. */
template<class T, class D>
struct has_tensor_types<T, tensor_expression<T, D>>
{
  static constexpr bool value =
    std::disjunction_v<std::is_same<T, D>, has_tensor_types<T, D>>;
};

/** @brief Binary node: true if either operand is T or contains T. */
template<class T, class EL, class ER, class OP>
struct has_tensor_types<T, binary_tensor_expression<T, EL, ER, OP>>
{
  static constexpr bool value =
    std::disjunction_v<std::is_same<T, EL>, std::is_same<T, ER>,
                       has_tensor_types<T, EL>, has_tensor_types<T, ER>>;
};

/** @brief Unary node: true if the single operand is T or contains T. */
template<class T, class E, class OP>
struct has_tensor_types<T, unary_tensor_expression<T, E, OP>>
{
  static constexpr bool value =
    std::disjunction_v<std::is_same<T, E>, has_tensor_types<T, E>>;
};

} // namespace boost::numeric::ublas::detail
namespace boost::numeric::ublas::detail {
/** @brief Retrieves extents of the tensor
 *
 * @param t tensor whose extents are requested
 * @returns t.extents() (returned by value, since the return type is deduced with `auto`)
 */
template<class T, class F, class A>
auto retrieve_extents(tensor<T,F,A> const& t)
{
  return t.extents();
}

/** @brief Retrieves extents of the tensor expression
 *
 * @note tensor expression must be a binary tree with at least one tensor type (checked by static_assert)
 *
 * @returns extents of the derived expression if it is a tensor, otherwise the
 *          extents found by dispatching to the overload for the derived expression type.
 */
template<class T, class D>
auto retrieve_extents(tensor_expression<T,D> const& expr)
{
  static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
    "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
  auto const& cast_expr = static_cast<D const&>(expr);
  if constexpr ( std::is_same<T,D>::value )
    return cast_expr.extents();
  else
    return retrieve_extents(cast_expr);
}

/** @brief Retrieves extents of the binary tensor expression
 *
 * @note tensor expression must be a binary tree with at least one tensor type (checked by static_assert)
 *
 * @returns extents of a direct tensor operand (left first, then right); otherwise
 *          the extents found in the left subtree if it contains a tensor, else in the right subtree.
 */
template<class T, class EL, class ER, class OP>
auto retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
{
  static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
    "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
  // One exclusive `if constexpr` chain: exactly one branch is instantiated,
  // so no dead `return` statements participate in return-type deduction.
  if constexpr ( std::is_same<T,EL>::value )
    return expr.el.extents();
  else if constexpr ( std::is_same<T,ER>::value )
    return expr.er.extents();
  else if constexpr ( detail::has_tensor_types<T,EL>::value )
    return retrieve_extents(expr.el);
  else if constexpr ( detail::has_tensor_types<T,ER>::value )
    return retrieve_extents(expr.er);
}

/** @brief Retrieves extents of the unary tensor expression
 *
 * @note tensor expression must be a binary tree with at least one tensor type (checked by static_assert)
 *
 * @returns extents of the operand if it is a tensor, otherwise the extents found inside the operand.
 */
template<class T, class E, class OP>
auto retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
{
  static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
    "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
  if constexpr ( std::is_same<T,E>::value )
    return expr.e.extents();
  else if constexpr ( detail::has_tensor_types<T,E>::value )
    return retrieve_extents(expr.e);
}
} // namespace boost::numeric::ublas::detail
///////////////
namespace boost::numeric::ublas::detail {
/** @brief Checks whether a tensor has the given extents
 *
 * @returns true if @p extents equals t.extents()
 */
template<class T, class F, class A, class S>
auto all_extents_equal(tensor<T,F,A> const& t, basic_extents<S> const& extents)
{
  return extents == t.extents();
}

/** @brief Checks whether every tensor inside a tensor expression has the given extents
 *
 * @note The expression must contain at least one tensor of type T (checked by static_assert).
 * @returns false as soon as one tensor's extents differ from @p extents, otherwise true
 */
template<class T, class D, class S>
auto all_extents_equal(tensor_expression<T,D> const& expr, basic_extents<S> const& extents)
{
  static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
    "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
  auto const& cast_expr = static_cast<D const&>(expr);
  // A tensor leaf is compared directly; any other node is dispatched to the
  // binary/unary overload. The branches are exclusive, so a leaf is compared
  // only once and never re-dispatched on itself.
  if constexpr ( std::is_same<T,D>::value ) {
    if( extents != cast_expr.extents() )
      return false;
  } else if constexpr ( detail::has_tensor_types<T,D>::value ) {
    if( !all_extents_equal(cast_expr, extents) )
      return false;
  }
  return true;
}

/** @brief Checks whether every tensor inside a binary tensor expression has the given extents
 *
 * Direct tensor operands are compared directly; operands that are expressions
 * containing tensors are checked recursively.
 *
 * @returns false as soon as one tensor's extents differ from @p extents, otherwise true
 */
template<class T, class EL, class ER, class OP, class S>
auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, basic_extents<S> const& extents)
{
  static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
    "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
  if constexpr ( std::is_same<T,EL>::value ) {
    if( extents != expr.el.extents() )
      return false;
  } else if constexpr ( detail::has_tensor_types<T,EL>::value ) {
    if( !all_extents_equal(expr.el, extents) )
      return false;
  }
  if constexpr ( std::is_same<T,ER>::value ) {
    if( extents != expr.er.extents() )
      return false;
  } else if constexpr ( detail::has_tensor_types<T,ER>::value ) {
    if( !all_extents_equal(expr.er, extents) )
      return false;
  }
  return true;
}

/** @brief Checks whether every tensor inside a unary tensor expression has the given extents
 *
 * @returns false if a tensor's extents differ from @p extents, otherwise true
 */
template<class T, class E, class OP, class S>
auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, basic_extents<S> const& extents)
{
  static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
    "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
  if constexpr ( std::is_same<T,E>::value ) {
    if( extents != expr.e.extents() )
      return false;
  } else if constexpr ( detail::has_tensor_types<T,E>::value ) {
    if( !all_extents_equal(expr.e, extents) )
      return false;
  }
  return true;
}
} // namespace boost::numeric::ublas::detail
namespace boost::numeric::ublas::detail {
/** @brief Evaluates an expression into a tensor
 *
 * Assigns expr()(i) to lhs(i) for every linear index i in [0, lhs.size()).
 *
 * @param lhs  target tensor
 * @param expr expression to evaluate
 * @throws std::runtime_error if a tensor within @p expr has extents different from lhs.extents()
 */
template<class tensor_type, class derived_type>
void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr)
{
  if constexpr (detail::has_tensor_types<tensor_type, tensor_expression<tensor_type,derived_type> >::value ) {
    if(!detail::all_extents_equal(expr, lhs.extents() ))
      throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
  }
  // Count with the tensor's own size type: a 32-bit `0u` counter wraps and
  // never terminates for tensors with more than UINT_MAX elements.
  using size_type = std::decay_t<decltype(lhs.size())>;
  size_type const n = lhs.size();
#pragma omp parallel for
  for(size_type i = 0; i < n; ++i)
    lhs(i) = expr()(i);
}

/** @brief Evaluates an expression into a tensor through a binary callback
 *
 * Calls fn(lhs(i), expr()(i)) for every linear index i in [0, lhs.size()).
 * Typically used to implement compound assignment operators such as A += C.
 *
 * @param lhs  target tensor
 * @param expr expression to evaluate
 * @param fn   callable invoked with the target element and the expression element
 * @throws std::runtime_error if a tensor within @p expr has extents different from lhs.extents()
 */
template<class tensor_type, class derived_type, class unary_fn>
void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr, unary_fn const fn)
{
  if constexpr (detail::has_tensor_types< tensor_type, tensor_expression<tensor_type,derived_type> >::value ) {
    if(!detail::all_extents_equal( expr, lhs.extents() ))
      throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
  }
  // See the overload above for why the counter uses the tensor's size type.
  using size_type = std::decay_t<decltype(lhs.size())>;
  size_type const n = lhs.size();
#pragma omp parallel for
  for(size_type i = 0; i < n; ++i)
    fn(lhs(i), expr()(i));
}

/** @brief Applies a callback to every element of a tensor
 *
 * Calls fn(lhs(i)) for every linear index i in [0, lhs.size()). No expression
 * is involved, so no shape check is performed.
 *
 * @param lhs target tensor
 * @param fn  callable invoked with each element (it can modify the element only
 *            if it takes its argument by non-const reference)
 */
template<class tensor_type, class unary_fn>
void eval(tensor_type& lhs, unary_fn const fn)
{
  using size_type = std::decay_t<decltype(lhs.size())>;
  size_type const n = lhs.size();
#pragma omp parallel for
  for(size_type i = 0; i < n; ++i)
    fn(lhs(i));
}
}
#endif