289 lines
8.7 KiB
C++
289 lines
8.7 KiB
C++
//
|
|
// Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
|
|
//
|
|
// Distributed under the Boost Software License, Version 1.0. (See
|
|
// accompanying file LICENSE_1_0.txt or copy at
|
|
// http://www.boost.org/LICENSE_1_0.txt)
|
|
//
|
|
// The authors gratefully acknowledge the support of
|
|
// Fraunhofer IOSB, Ettlingen, Germany
|
|
//
|
|
|
|
#ifndef _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
|
|
#define _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
|
|
|
|
#include <type_traits>
|
|
#include <stdexcept>
|
|
|
|
|
|
namespace boost::numeric::ublas {

// Forward declarations only: this header needs the names `tensor` and
// `basic_extents` to spell the parameter types of the detail helpers, but
// it does not define either class template.
template<typename element_t, typename format_t, typename storage_t>
class tensor;

template<typename extent_t>
class basic_extents;

} // namespace boost::numeric::ublas
|
|
|
|
namespace boost::numeric::ublas::detail {

// Forward declarations of the expression-template node types that the
// traits and helpers below pattern-match on. Their definitions are not in
// this header.

// Generic expression base, parameterised by tensor type and derived type.
template<typename TensorT, typename DerivedT>
struct tensor_expression;

// Node with a left operand, a right operand and an operation.
template<typename TensorT, typename LeftT, typename RightT, typename OpT>
struct binary_tensor_expression;

// Node with a single operand and an operation.
template<typename TensorT, typename ExprT, typename OpT>
struct unary_tensor_expression;

} // namespace boost::numeric::ublas::detail
|
|
|
|
namespace boost::numeric::ublas::detail {
|
|
|
|
template<class T, class E>
|
|
struct has_tensor_types
|
|
{ static constexpr bool value = false; };
|
|
|
|
template<class T>
|
|
struct has_tensor_types<T,T>
|
|
{ static constexpr bool value = true; };
|
|
|
|
template<class T, class D>
|
|
struct has_tensor_types<T, tensor_expression<T,D>>
|
|
{ static constexpr bool value = std::is_same<T,D>::value || has_tensor_types<T,D>::value; };
|
|
|
|
|
|
template<class T, class EL, class ER, class OP>
|
|
struct has_tensor_types<T, binary_tensor_expression<T,EL,ER,OP>>
|
|
{ static constexpr bool value = std::is_same<T,EL>::value || std::is_same<T,ER>::value || has_tensor_types<T,EL>::value || has_tensor_types<T,ER>::value; };
|
|
|
|
template<class T, class E, class OP>
|
|
struct has_tensor_types<T, unary_tensor_expression<T,E,OP>>
|
|
{ static constexpr bool value = std::is_same<T,E>::value || has_tensor_types<T,E>::value; };
|
|
|
|
} // namespace boost::numeric::ublas::detail
|
|
|
|
|
|
namespace boost::numeric::ublas::detail {
|
|
|
|
|
|
|
|
|
|
|
|
/** @brief Retrieves extents of the tensor
|
|
*
|
|
*/
|
|
template<class T, class F, class A>
|
|
auto retrieve_extents(tensor<T,F,A> const& t)
|
|
{
|
|
return t.extents();
|
|
}
|
|
|
|
/** @brief Retrieves extents of the tensor expression
|
|
*
|
|
* @note tensor expression must be a binary tree with at least one tensor type
|
|
*
|
|
* @returns extents of the child expression if it is a tensor or extents of one child of its child.
|
|
*/
|
|
template<class T, class D>
|
|
auto retrieve_extents(tensor_expression<T,D> const& expr)
|
|
{
|
|
static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
|
|
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
|
|
|
|
auto const& cast_expr = static_cast<D const&>(expr);
|
|
|
|
if constexpr ( std::is_same<T,D>::value )
|
|
return cast_expr.extents();
|
|
else
|
|
return retrieve_extents(cast_expr);
|
|
}
|
|
|
|
/** @brief Retrieves extents of the binary tensor expression
|
|
*
|
|
* @note tensor expression must be a binary tree with at least one tensor type
|
|
*
|
|
* @returns extents of the (left and if necessary then right) child expression if it is a tensor or extents of a child of its (left and if necessary then right) child.
|
|
*/
|
|
template<class T, class EL, class ER, class OP>
|
|
auto retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
|
|
{
|
|
static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
|
|
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
|
|
|
|
if constexpr ( std::is_same<T,EL>::value )
|
|
return expr.el.extents();
|
|
|
|
if constexpr ( std::is_same<T,ER>::value )
|
|
return expr.er.extents();
|
|
|
|
else if constexpr ( detail::has_tensor_types<T,EL>::value )
|
|
return retrieve_extents(expr.el);
|
|
|
|
else if constexpr ( detail::has_tensor_types<T,ER>::value )
|
|
return retrieve_extents(expr.er);
|
|
}
|
|
|
|
/** @brief Retrieves extents of the binary tensor expression
|
|
*
|
|
* @note tensor expression must be a binary tree with at least one tensor type
|
|
*
|
|
* @returns extents of the child expression if it is a tensor or extents of a child of its child.
|
|
*/
|
|
template<class T, class E, class OP>
|
|
auto retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
|
|
{
|
|
|
|
static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
|
|
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
|
|
|
|
if constexpr ( std::is_same<T,E>::value )
|
|
return expr.e.extents();
|
|
|
|
else if constexpr ( detail::has_tensor_types<T,E>::value )
|
|
return retrieve_extents(expr.e);
|
|
}
|
|
|
|
} // namespace boost::numeric::ublas::detail
|
|
|
|
|
|
///////////////
|
|
|
|
namespace boost::numeric::ublas::detail {
|
|
|
|
template<class T, class F, class A, class S>
|
|
auto all_extents_equal(tensor<T,F,A> const& t, basic_extents<S> const& extents)
|
|
{
|
|
return extents == t.extents();
|
|
}
|
|
|
|
template<class T, class D, class S>
|
|
auto all_extents_equal(tensor_expression<T,D> const& expr, basic_extents<S> const& extents)
|
|
{
|
|
static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
|
|
"Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
|
|
auto const& cast_expr = static_cast<D const&>(expr);
|
|
|
|
|
|
if constexpr ( std::is_same<T,D>::value )
|
|
if( extents != cast_expr.extents() )
|
|
return false;
|
|
|
|
if constexpr ( detail::has_tensor_types<T,D>::value )
|
|
if ( !all_extents_equal(cast_expr, extents))
|
|
return false;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
template<class T, class EL, class ER, class OP, class S>
|
|
auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, basic_extents<S> const& extents)
|
|
{
|
|
static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
|
|
"Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
|
|
|
|
if constexpr ( std::is_same<T,EL>::value )
|
|
if(extents != expr.el.extents())
|
|
return false;
|
|
|
|
if constexpr ( std::is_same<T,ER>::value )
|
|
if(extents != expr.er.extents())
|
|
return false;
|
|
|
|
if constexpr ( detail::has_tensor_types<T,EL>::value )
|
|
if(!all_extents_equal(expr.el, extents))
|
|
return false;
|
|
|
|
if constexpr ( detail::has_tensor_types<T,ER>::value )
|
|
if(!all_extents_equal(expr.er, extents))
|
|
return false;
|
|
|
|
return true;
|
|
}
|
|
|
|
|
|
template<class T, class E, class OP, class S>
|
|
auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, basic_extents<S> const& extents)
|
|
{
|
|
|
|
static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
|
|
"Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
|
|
|
|
if constexpr ( std::is_same<T,E>::value )
|
|
if(extents != expr.e.extents())
|
|
return false;
|
|
|
|
if constexpr ( detail::has_tensor_types<T,E>::value )
|
|
if(!all_extents_equal(expr.e, extents))
|
|
return false;
|
|
|
|
return true;
|
|
}
|
|
|
|
} // namespace boost::numeric::ublas::detail
|
|
|
|
|
|
namespace boost::numeric::ublas::detail {
|
|
|
|
|
|
/** @brief Evaluates expression for a tensor
|
|
*
|
|
* Assigns the results of the expression to the tensor.
|
|
*
|
|
* \note Checks if shape of the tensor matches those of all tensors within the expression.
|
|
*/
|
|
template<class tensor_type, class derived_type>
|
|
void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr)
|
|
{
|
|
if constexpr (detail::has_tensor_types<tensor_type, tensor_expression<tensor_type,derived_type> >::value )
|
|
if(!detail::all_extents_equal(expr, lhs.extents() ))
|
|
throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
|
|
|
|
#pragma omp parallel for
|
|
for(auto i = 0u; i < lhs.size(); ++i)
|
|
lhs(i) = expr()(i);
|
|
}
|
|
|
|
/** @brief Evaluates expression for a tensor
|
|
*
|
|
* Applies a unary function to the results of the expressions before the assignment.
|
|
* Usually applied needed for unary operators such as A += C;
|
|
*
|
|
* \note Checks if shape of the tensor matches those of all tensors within the expression.
|
|
*/
|
|
template<class tensor_type, class derived_type, class unary_fn>
|
|
void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr, unary_fn const fn)
|
|
{
|
|
|
|
if constexpr (detail::has_tensor_types< tensor_type, tensor_expression<tensor_type,derived_type> >::value )
|
|
if(!detail::all_extents_equal( expr, lhs.extents() ))
|
|
throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
|
|
|
|
#pragma omp parallel for
|
|
for(auto i = 0u; i < lhs.size(); ++i)
|
|
fn(lhs(i), expr()(i));
|
|
}
|
|
|
|
|
|
|
|
/** @brief Evaluates expression for a tensor
|
|
*
|
|
* Applies a unary function to the results of the expressions before the assignment.
|
|
* Usually applied needed for unary operators such as A += C;
|
|
*
|
|
* \note Checks if shape of the tensor matches those of all tensors within the expression.
|
|
*/
|
|
template<class tensor_type, class unary_fn>
|
|
void eval(tensor_type& lhs, unary_fn const fn)
|
|
{
|
|
#pragma omp parallel for
|
|
for(auto i = 0u; i < lhs.size(); ++i)
|
|
fn(lhs(i));
|
|
}
|
|
|
|
|
|
}
|
|
#endif
|