2008-09-18

Expression Templateを解説しようとしたら収拾がつかなくなった

Expression Templateの検索クエリで飛んでくる人が多い。ここはひとつ、ETの簡単な説明をビシッと書いておこうと思い立った。それには、具体的なサンプルコードが必要だ。一番分かりやすいのは、配列同士の各要素への演算を、Lazyに行うことだろう。

えーと、各要素の足し引きができればいいかな。おっと、スカラー倍できるとなおいいな。いやまて、int型を直接使うのは面倒だ。いやなに、"We can solve any problem by introducing an extra level of indirection."という言葉がある。問題ない。さて、最後の仕上げに、遅らせてきた演算を実行、ってちょとまて、要素の数が分からないといけないな。コンパイル時にできるかな。いや、やっぱりいらんな。

なんてことをしていたら、サンプルコードが複雑になりすぎた。どうしてくれようこの長ったらしいコード。

#define USE_EXPRESSION_TEMPLATE


template < size_t N >
class Array
{
public :
    Array()
    {
        ptr = new int[N] ;
    }
    
    Array( Array const & array )
    {
        ptr = new int[N] ;
        memcpy( ptr, array.ptr, sizeof(int) * N ) ;
    }
    
    template < typename EXPR >
    Array( EXPR const & expr )
    {
        ptr = new int[ N ] ;
        for ( size_t i = 0 ; i != N ; ++i )
        ptr[i] = expr[i] ;
    }
    
    
#ifdef COMPILER_SUPPORT_RVALUE_REFERENCE
    Array( Array && array )
    {
        ptr = array.ptr ;
        array.ptr = nullptr ;
    }
#endif // COMPILER_SUPPORT_RVALUE_REFERENCE

    template < typename EXPR >
    Array & operator = ( EXPR const & expr )
    {
        ptr = new int[ N ] ;
        for ( size_t i = 0 ; i != N ; ++i )
        ptr[i] = expr[i] ;
        
        return *this ;
    }
    
    Array & operator = ( Array const & r )
    {
        ptr = new int[N] ;
        memcpy( ptr, r.ptr, sizeof(int) * N ) ;
        return *this ;        
    }
    
    ~Array()
    {
        delete[] ptr ;
    }
    
    int operator [] (size_t i) const
    {
        return ptr[i] ;
    }
    
    int & operator [] (size_t i)
    {
        return ptr[i] ;
    }

public :
    static const size_t size = N ;
    
private :
    int * ptr ;
} ;



#ifndef USE_EXPRESSION_TEMPLATE
// ordinary implementation.
template < size_t N >
Array<N> operator + ( Array<N> const & l, Array<N> const & r )
{
    Array<N> result ;
    
    for ( int i = 0 ; i != N ; ++i )
        result[i] = l[i] + r[i] ;

    return result ;
}

template < size_t N >
Array<N> operator - ( Array<N> const & l, Array<N> const & r )
{
    Array<N> result ;
    
    for ( int i = 0 ; i != N ; ++i )
        result[i] = l[i] - r[i] ;

    return result ;
}

template < size_t N >
Array<N> operator * ( int x , Array<N> const & r)
{
    Array<N> result ;
    
    for ( int i = 0 ; i != N ; ++i )
        result[i] = x * r[i] ;

    return result ;
}

template < size_t N >
Array<N> operator * ( Array<N> const & l, int x)
{
    Array<N> result ;
    
    for ( int i = 0 ; i != N ; ++i )
        result[i] = x * l[i] ;

    return result ;
}

#else USE_EXPRESSION_TEMPLATE
// Expression Template

struct plus
{
    static int apply( int a, int b )
    { return a + b ; }

} ;

struct minus
{
    static int apply( int a, int b )
    { return a - b ; }
} ;

struct mul
{
    
    static int apply( int a, int b )
    { return a * b ; }
} ;


template < typename L, typename Op, typename R >
class Expression
{
public :
    Expression( L const & l, R const & r )
        : l(l), r(r) { }
        
    int operator [] ( size_t i ) const
    {
        return Op::apply( l[i], r[i] ) ;
    }
        

private :
    L const & l ;
    R const & r ;
} ;

template < typename L, typename R >
Expression< L, plus, R > operator + (L const & l, R const & r)
{
    return Expression< L, plus, R >(l, r) ;
}

template < typename L, typename R >
Expression< L, minus, R > operator - (L const & l, R const & r)
{
    return Expression< L, minus, R >(l, r) ;
}


// this class is a wrapper for int.
class IntHolder
{
public :
    IntHolder(int value)
        : value(value) {} ;
    int operator [] ( size_t ) const
    { return value ; }// always returns same value.
private :
    int const & value ;
} ;



template < typename L >
Expression< L, mul, IntHolder > operator * (L const & l, int r)
{
    return Expression< L, mul, IntHolder >(l, IntHolder(r) ) ;
}

template < typename R >
Expression< IntHolder, mul, R > operator * (int l, R const & r)
{
    return Expression< IntHolder, mul, R >(IntHolder(l), r ) ;
}

#endif

int main()
{
    size_t const N = 10000000 ;
    Array<N> a, b, c, result ;
    
    // 適当に初期化
    for ( int i = 0 ; i != N ; ++i )
    {
        a[i] = 1 ;
        b[i] = 10 ;
        c[i] = 100 ;
    }
    

    result = 123 * a + 40 * b - 2 * c * 2 + b + a + 10 * c * 40 ;

    
    std::cout << typeid(123 * a + 40 * b - 2 * c * 2 + b + a + 10 * c * 40).name() << std::endl ;
    std::cout << result[0] << std::endl ;
    
}

結果は次のとおり。

class Expression<class Expression<class Expression<class Expression<class Expression<class Expression<class IntHolder,struct mul,class Array<10000000> >,structplus,class Expression<class IntHolder,struct mul,class Array<10000000> > >,struct minus,class Expression<class Expression<class IntHolder,struct mul,class Array<10000000> >,struct mul,class IntHolder> >,struct plus,class Array<10000000> >,struct plus,class Array<10000000> >,struct plus,class Expression<class Expression<class IntHolder,struct mul,class Array<10000000> >,struct mul,class IntHolder>>
40134

要するにコンパイル時にテンプレートパラメーターを使って、式の構造を捕らえているに過ぎない。

この途中の型を捕らえておくために、autoが必要だ。

それに、Expression Templateは、何もわざわざ名前をつけるほどの技巧でもないような気がする。Andrew KoenigのRuminations on C++で腐るほど例がある。もっとも、あちらは実行時の話で、こちらはコンパイル時の話だが。

No comments: