jackee777のブログ

情報系学生のつぶやき

cython で numpy の dot を学ぶ - part 1

※こちらの記事は,暇があるたびに詳しくなっていきます(2019/06/27)

はじめに

どうも,cython の良い導入ってないよなあって感じてるものです.最近は cpp も使えるらしく,非常に便利だと思います. 当方,O'REILLY の Cython は昔読んだかもですが,読んだかを忘れました.

中身

cython の導入にはややこしいことがあるので,それを理解する.そのために,numpy.dot を実装する.全部のコードはこちら

ファイル構成

まず,ファイル構成を以下のようにして,setup.py を書きます.

cython_lib
- __init__.py
- calculation.pxd
- calculation.pyx
setup.py

これだけで基本的には十分.python setup.py install を実行すると,cython_lib/calculation.c が出来ますが,自動生成なので読むものじゃないです.

setup.py を書く

参考はこちら

from Cython.Distutils import build_ext
from Cython.Build import cythonize
import numpy as np
from setuptools import setup, Extension

ext_modules = [
    Extension('cython_lib.calculation', # module name
              sources=['cython_lib/calculation.pyx'], # necessary source code(.pyx and .c)
              include_dirs=[np.get_include()], # include files
              compiler_directives={'language_level' : "3"},
              ),
]

setup(
    name='cython_lib',
    packages=['cython_lib'],
    version='0.0.1',
    ext_modules=cythonize(ext_modules),
    cmdclass={'build_ext': build_ext}
)

cython_lib/calculation.pyx を書く

とりあえず,dot をテストするだけなら,ただの python です.

#!/usr/bin/env cython
# cython: boundscheck=False
# cython: wraparound=False
# cython: cdivision=True
# cython: embedsignature=True
# coding: utf-8

cimport cython
import numpy as np
cimport numpy as np

# 1×1 が答えの時
cpdef REAL_t cnum_dot_one(np.ndarray X0, np.ndarray X1):
    return np.dot(X0, X1)

# n×n が答えの時
cpdef np.ndarray cnum_dot(np.ndarray X0, np.ndarray X1):
    return np.dot(X0, X1)

 

cython_lib/calculation.pxd を書く

.pxd ファイルは c 言語で言う header ファイルと同じなので,1×1 が答えが答えとなる時の REAL_t の型だけ定めます.

import numpy as np
cimport numpy as np

ctypedef np.float32_t REAL_t

速度の計測

これはただの numpy なので,速度に違いはない.

本題

cython を適当に書くだけだと簡単なのはわかったと思うので,じゃあちゃんと c 言語で書こうと思う.としても,dot の比較をしても,numpy 相手では速度が上がらないという同じ結論になります. (次元数による比較によると,次元数が小さい時は jit が最速ですがそうでない時は numpy が速いそう.)

なんで numpy がこんな速いのかを理解していく過程で,cython を学びましょう.大体学べる.

とりあえず,ここまで

予定としては,openblas を使って numpy に追いつくところまでやります.

本当は wrapper を作る必要があるので,そこは勉強していきたいですね.また,openblas も scipy の linalg も元は fortran なので,fortran からの適用を学ぶ必要もあります(fortran は読めもしないけど).