Skip to content

Commit

Permalink
ZLUDA v3.8.7 (#66)
Browse files Browse the repository at this point in the history
* Add dummy cuFFTW library.

* Bump version.

* Implement fft functions required to run torch fftn, ifftn, and rfftn.
  • Loading branch information
lshqqytiger authored Jan 15, 2025
1 parent d60bddb commit c4994b3
Show file tree
Hide file tree
Showing 14 changed files with 1,049 additions and 193 deletions.
6 changes: 5 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

131 changes: 66 additions & 65 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,65 +1,66 @@
[workspace]

resolver = "2"

# Remember to also update the project's Cargo.toml
# if it's a top-level project
members = [
"atiadlxx-sys",
"comgr",
"cuda_base",
"cuda_types",
"detours-sys",
"ext/llvm-sys.rs",
"hip_common",
"hip_runtime-sys",
"hipblaslt-sys",
"hipfft-sys",
"hiprt-sys",
"miopen-sys",
"offline_compiler",
"optix_base",
"optix_dump",
"process_address_table",
"ptx",
"rocblas-sys",
"rocm_smi-sys",
"rocsparse-sys",
"xtask",
"zluda",
"zluda_api",
"zluda_blas",
"zluda_blaslt",
"zluda_ccl",
"zluda_dark_api",
"zluda_dnn",
"zluda_dump",
"zluda_fft",
"zluda_inject",
"zluda_lib",
"zluda_llvm",
"zluda_ml",
"zluda_redirect",
"zluda_rt",
"zluda_rtc",
"zluda_runtime",
"zluda_sparse",
]

# Cargo does not support OS-specific or profile-specific
# targets. We keep the list here to a bare minimum and rely on xtask
default-members = [
"zluda_lib",
"zluda_ml",
"zluda_inject",
"zluda_redirect"
]

[profile.dev.package.blake3]
opt-level = 3

[profile.dev.package.lz4-sys]
opt-level = 3

[profile.dev.package.xtask]
opt-level = 2
[workspace]

resolver = "2"

# Remember to also update the project's Cargo.toml
# if it's a top-level project
members = [
"atiadlxx-sys",
"comgr",
"cuda_base",
"cuda_types",
"detours-sys",
"ext/llvm-sys.rs",
"hip_common",
"hip_runtime-sys",
"hipblaslt-sys",
"hipfft-sys",
"hiprt-sys",
"miopen-sys",
"offline_compiler",
"optix_base",
"optix_dump",
"process_address_table",
"ptx",
"rocblas-sys",
"rocm_smi-sys",
"rocsparse-sys",
"xtask",
"zluda",
"zluda_api",
"zluda_blas",
"zluda_blaslt",
"zluda_ccl",
"zluda_dark_api",
"zluda_dnn",
"zluda_dump",
"zluda_fft",
"zluda_fftw",
"zluda_inject",
"zluda_lib",
"zluda_llvm",
"zluda_ml",
"zluda_redirect",
"zluda_rt",
"zluda_rtc",
"zluda_runtime",
"zluda_sparse",
]

# Cargo does not support OS-specific or profile-specific
# targets. We keep the list here to a bare minimum and rely on xtask
default-members = [
"zluda_lib",
"zluda_ml",
"zluda_inject",
"zluda_redirect"
]

[profile.dev.package.blake3]
opt-level = 3

[profile.dev.package.lz4-sys]
opt-level = 3

[profile.dev.package.xtask]
opt-level = 2
1 change: 1 addition & 0 deletions hipblaslt-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ impl hipblasOperation_t {
// Associated constant for one of the `hipblasOperation_t` values.
impl hipblasOperation_t {
/// Operation selector with raw value 113.
/// NOTE(review): presumably the conjugate-transpose operation, going by the
/// `_OP_C` suffix — confirm against the hipBLAS headers.
pub const HIPBLAS_OP_C: hipblasOperation_t = hipblasOperation_t(113);
}
/// Newtype around a C `int` that carries a hipBLAS operation selector.
///
/// `#[repr(transparent)]` gives it the same layout as the wrapped
/// `c_int`; the individual values are exposed as associated constants.
// The type name is not camel case, so the naming lint is silenced here.
#[allow(non_camel_case_types)]
#[repr(transparent)]
#[derive(Copy, Clone, Hash, PartialEq, Eq)]
pub struct hipblasOperation_t(pub ::std::os::raw::c_int);
107 changes: 12 additions & 95 deletions zluda_blas/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#![allow(warnings)]
#[allow(warnings)]
mod common;
#[allow(warnings)]
mod cublas;
#[allow(warnings)]
mod cublasxt;

pub use common::*;
Expand All @@ -13,7 +15,7 @@ use rocsolver_sys::{
rocsolver_cgetrf_batched, rocsolver_cgetri_outofplace_batched, rocsolver_dgetrs_batched,
rocsolver_sgetrs_batched, rocsolver_zgetrf_batched, rocsolver_zgetri_outofplace_batched,
};
use std::{mem, ptr};
use std::ptr;

#[cfg(debug_assertions)]
pub(crate) fn unsupported() -> cublasStatus_t {
Expand Down Expand Up @@ -223,61 +225,20 @@ unsafe fn set_stream(handle: cublasHandle_t, stream_id: cudaStream_t) -> cublasS
) -> CUresult>(b"cuGetExportTable\0")
.unwrap();
let mut export_table = ptr::null();
(cu_get_export_table)(&mut export_table, &zluda_dark_api::ZludaExt::GUID);
assert_eq!(
(cu_get_export_table)(&mut export_table, &zluda_dark_api::ZludaExt::GUID),
CUresult::CUDA_SUCCESS
);
let zluda_ext = zluda_dark_api::ZludaExt::new(export_table);
let stream: Result<_, _> = zluda_ext.get_hip_stream(stream_id as _).into();
to_cuda(rocblas_set_stream(handle as _, stream.unwrap() as _))
}

fn set_math_mode(handle: cublasHandle_t, mode: cublasMath_t) -> cublasStatus_t {
/// Accepts a math-mode request and reports success without doing anything.
///
/// Both arguments are ignored (hence the leading underscores); nothing is
/// forwarded to the backend. The function always returns
/// `CUBLAS_STATUS_SUCCESS`.
fn set_math_mode(_handle: cublasHandle_t, _mode: cublasMath_t) -> cublasStatus_t {
// llama.cpp uses CUBLAS_TF32_TENSOR_OP_MATH
cublasStatus_t::CUBLAS_STATUS_SUCCESS
}

unsafe fn sgemm(
transa: std::ffi::c_char,
transb: std::ffi::c_char,
m: i32,
n: i32,
k: i32,
alpha: f32,
a: *const f32,
lda: i32,
b: *const f32,
ldb: i32,
beta: f32,
c: *mut f32,
ldc: i32,
) -> cublasStatus_t {
let mut handle = mem::zeroed();
let mut status = to_cuda(rocblas_create_handle(handle));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
let transa = op_from_cuda(cublasOperation_t(transa as _));
let transb = op_from_cuda(cublasOperation_t(transb as _));
status = to_cuda(rocblas_sgemm(
handle.cast(),
transa,
transb,
m,
n,
k,
&alpha,
a,
lda,
b,
ldb,
&beta,
c,
ldc,
));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
to_cuda(rocblas_destroy_handle(*handle))
}

unsafe fn sgemm_v2(
handle: cublasHandle_t,
transa: cublasOperation_t,
Expand Down Expand Up @@ -495,7 +456,7 @@ unsafe fn gemm_ex(
))
}

fn to_algo(algo: cublasGemmAlgo_t) -> rocblas_gemm_algo_ {
/// Maps a cuBLAS GEMM algorithm selector to its rocBLAS counterpart.
///
/// The requested algorithm is ignored: every input yields
/// `rocblas_gemm_algo_standard`.
fn to_algo(_algo: cublasGemmAlgo_t) -> rocblas_gemm_algo_ {
// only option
rocblas_gemm_algo::rocblas_gemm_algo_standard
}
Expand Down Expand Up @@ -807,7 +768,7 @@ unsafe fn sgetrs_batched(
dev_ipiv: *const i32,
b: *const *mut f32,
ldb: i32,
info: *mut i32,
_info: *mut i32,
batch_size: i32,
) -> cublasStatus_t {
let trans = op_from_cuda_for_solver(trans);
Expand Down Expand Up @@ -837,7 +798,7 @@ unsafe fn dgetrs_batched(
dev_ipiv: *const i32,
b: *const *mut f64,
ldb: i32,
info: *mut i32,
_info: *mut i32,
batch_size: i32,
) -> cublasStatus_t {
let trans = op_from_cuda_for_solver(trans);
Expand Down Expand Up @@ -1048,50 +1009,6 @@ unsafe fn dger(
))
}

unsafe fn dgemm(
transa: std::ffi::c_char,
transb: std::ffi::c_char,
m: i32,
n: i32,
k: i32,
alpha: f64,
a: *const f64,
lda: i32,
b: *const f64,
ldb: i32,
beta: f64,
c: *mut f64,
ldc: i32,
) -> cublasStatus_t {
let mut handle = mem::zeroed();
let mut status = to_cuda(rocblas_create_handle(handle));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
let transa = op_from_cuda(cublasOperation_t(transa as _));
let transb = op_from_cuda(cublasOperation_t(transb as _));
status = to_cuda(rocblas_dgemm(
handle.cast(),
transa,
transb,
m,
n,
k,
&alpha,
a,
lda,
b,
ldb,
&beta,
c,
ldc,
));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
to_cuda(rocblas_destroy_handle(*handle))
}

unsafe fn dgemm_v2(
handle: *mut cublasContext,
transa: cublasOperation_t,
Expand Down
4 changes: 2 additions & 2 deletions zluda_fft/src/cufft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,15 +380,15 @@ pub unsafe extern "system" fn cufftSetWorkArea(
plan: cufftHandle,
workArea: *mut ::std::os::raw::c_void,
) -> cufftResult {
crate::unsupported()
crate::set_work_area(plan, workArea)
}

#[no_mangle]
pub unsafe extern "system" fn cufftSetAutoAllocation(
plan: cufftHandle,
autoAllocate: ::std::os::raw::c_int,
) -> cufftResult {
crate::unsupported()
crate::set_auto_allocation(plan, autoAllocate)
}

#[no_mangle]
Expand Down
19 changes: 17 additions & 2 deletions zluda_fft/src/cufftxt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,22 @@ pub unsafe extern "system" fn cufftXtMakePlanMany(
workSize: *mut usize,
executiontype: cudaDataType,
) -> cufftResult {
crate::unsupported()
crate::xt_make_plan_many(
plan,
rank,
n,
inembed,
istride,
idist,
inputtype,
onembed,
ostride,
odist,
outputtype,
batch,
workSize,
executiontype,
)
}

#[no_mangle]
Expand Down Expand Up @@ -406,7 +421,7 @@ pub unsafe extern "system" fn cufftXtExec(
output: *mut ::std::os::raw::c_void,
direction: ::std::os::raw::c_int,
) -> cufftResult {
crate::unsupported()
crate::xt_exec(plan, input, output, direction)
}

#[no_mangle]
Expand Down
Loading

0 comments on commit c4994b3

Please sign in to comment.