[Rust] cuda_sys::cublas使用時にSTATUS_ACCESS_VIOLATIONが出るときの対策

2 min 356 words
Suzuki Shun Placeholder text describing the default author's avatar.

Categories: posts

Tags: Rust

TL;DR: alignmentに気をつけよう.

crate.ioに公開されているcuda_sysクレートにはバグが有り, cuDoubleComplexのアライメントが指定されていない. そのため, 場合によっては, アライメントがずれてSTATUS_ACCESS_VIOLATIONが出ることがある.

例えば, 以下のコードはSTATUS_ACCESS_VIOLATIONが出る (環境によると思う).

use cuda_sys::cublas::*;

struct Wraper {
    handle: cublasHandle_t,
}

impl Wraper {
    fn new() -> Self {
        let mut handle: cublasHandle_t = std::ptr::null_mut();
        unsafe {
            cublasCreate_v2(&mut handle as _);
        }
        Self { handle }
    }

    fn zscal(&self, n: usize, r: f64, x: *mut cuDoubleComplex) {
        let align: f64 = 0.;
        dbg!(&align as *const _);
        let alpha = cuDoubleComplex { x: r, y: align };
        dbg!(&alpha as *const _);
        unsafe {
            cublasZscal_v2(self.handle, n as _, &alpha as *const _ as _, x as _, 1);
        }
    }
}

fn main() {
    let w = Wraper::new();

    let mut p = vec![cuDoubleComplex { x: 1., y: 2. }];
    unsafe {
        let mut dp: *mut cuDoubleComplex = std::ptr::null_mut();
        cuda_sys::cudart::cudaMalloc(
            &mut dp as *mut _ as _,
            std::mem::size_of::<cuDoubleComplex>() * p.len(),
        );
        cuda_sys::cudart::cudaMemcpy(
            dp as _,
            p.as_ptr() as _,
            std::mem::size_of::<cuDoubleComplex>() * p.len(),
            cuda_sys::cudart::cudaMemcpyKind_cudaMemcpyHostToDevice,
        );

        w.zscal(p.len(), 2., dp);

        cuda_sys::cudart::cudaMemcpy(
            p.as_mut_ptr() as _,
            dp as _,
            std::mem::size_of::<cuDoubleComplex>() * p.len(),
            cuda_sys::cudart::cudaMemcpyKind_cudaMemcpyDeviceToHost,
        );
    }

    dbg!(p);
}

zscal内の, align変数を消すと, 正常に動作する.

-        let align: f64 = 0.;
-        dbg!(&align as *const _);
-        let alpha = cuDoubleComplex { x: r, y: align };
+        // let align: f64 = 0.;
+        // dbg!(&align as *const _);
+        let alpha = cuDoubleComplex { x: r, y: 0. };

cuda-sys version 0.3.0-alphaでは, この問題は修正されているようだが, crate.ioにはまだ公開されていない. 簡単な解決策は, アライメントを指定した構造体でラップすること.

#[repr(C)]
#[repr(align(16))]
struct CuDoubleComplexWrapper(cuDoubleComplex);

...

        let align: f64 = 0.;
        dbg!(&align as *const _);
        let alpha = CuDoubleComplexWrapper(cuDoubleComplex { x: r, y: align });