DDR爱好者之家 Design By 杰米
以下面这个例子作为教程,实现功能是element-wise add;
(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)
第一步:cuda编程的源文件和头文件
// mathutil_cuda_kernel.cu // 头文件,最后一个是cuda特有的 #include <curand.h> #include <stdio.h> #include <math.h> #include <float.h> #include "mathutil_cuda_kernel.h" // 获取GPU线程通道信息 dim3 cuda_gridsize(int n) { int k = (n - 1) / BLOCK + 1; int x = k; int y = 1; if(x > 65535) { x = ceil(sqrt(k)); y = (n - 1) / (x * BLOCK) + 1; } dim3 d(x, y, 1); return d; } // 这个函数是cuda执行函数,可以看到细化到了每一个元素 __global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size) { int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; if(i >= size) return; int j = i % x; i = i / x; int k = i % y; a[IDX2D(j, k, y)] += b[k]; } // 这个函数是与c语言函数链接的接口函数 void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream) { int size = x * y; cudaError_t err; // 上面定义的函数 broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>(a, b, x, y, size); err = cudaGetLastError(); if (cudaSuccess != err) { fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); exit(-1); } }
#ifndef _MATHUTIL_CUDA_KERNEL #define _MATHUTIL_CUDA_KERNEL #define IDX2D(i, j, dj) (dj * i + j) #define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk)) #define BLOCK 512 #define MAX_STREAMS 512 #ifdef __cplusplus extern "C" { #endif void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream); #ifdef __cplusplus } #endif #endif
第二步:C编程的源文件和头文件(接口函数)
// mathutil_cuda.c // THC是pytorch底层GPU库 #include <THC/THC.h> #include "mathutil_cuda_kernel.h" extern THCState *state; int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y) { float *a = THCudaTensor_data(state, a_tensor); float *b = THCudaTensor_data(state, b_tensor); cudaStream_t stream = THCState_getCurrentStream(state); // 这里调用之前在cuda中编写的接口函数 broadcast_sum_cuda(a, b, x, y, stream); return 1; }
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);
第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)
nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os import torch from torch.utils.ffi import create_extension this_file = os.path.dirname(__file__) sources = [] headers = [] defines = [] with_cuda = False if torch.cuda.is_available(): print('Including CUDA code.') sources += ['src/mathutil_cuda.c'] headers += ['src/mathutil_cuda.h'] defines += [('WITH_CUDA', None)] with_cuda = True this_file = os.path.dirname(os.path.realpath(__file__)) extra_objects = ['src/mathutil_cuda_kernel.cu.o'] # 这里是编译好后的.o文件位置 extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] ffi = create_extension( '_ext.cuda_util', headers=headers, sources=sources, define_macros=defines, relative_to=__file__, with_cuda=with_cuda, extra_objects=extra_objects ) if __name__ == '__main__': ffi.build()
第四步:调用cuda模块
from _ext import cuda_util #从对应路径中调用编译好的模块 a = torch.randn(3, 5).cuda() b = torch.randn(3, 1).cuda() mathutil.broadcast_sum(a, b, *map(int, a.size())) # 上面等价于下面的效果: a = torch.randn(3, 5) b = torch.randn(3, 1) a += b
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。
DDR爱好者之家 Design By 杰米
广告合作:本站广告合作请联系QQ:858582 申请时备注:广告合作(否则不回)
免责声明:本站资源来自互联网收集,仅供用于学习和交流,请遵循相关法律法规,本站一切资源不代表本站立场,如有侵权、后门、不妥请联系本站删除!
免责声明:本站资源来自互联网收集,仅供用于学习和交流,请遵循相关法律法规,本站一切资源不代表本站立场,如有侵权、后门、不妥请联系本站删除!
DDR爱好者之家 Design By 杰米
暂无评论...
RTX 5090要首发 性能要翻倍!三星展示GDDR7显存
三星在GTC上展示了专为下一代游戏GPU设计的GDDR7内存。
首次推出的GDDR7内存模块密度为16GB,每个模块容量为2GB。其速度预设为32 Gbps(PAM3),但也可以降至28 Gbps,以提高产量和初始阶段的整体性能和成本效益。
据三星表示,GDDR7内存的能效将提高20%,同时工作电压仅为1.1V,低于标准的1.2V。通过采用更新的封装材料和优化的电路设计,使得在高速运行时的发热量降低,GDDR7的热阻比GDDR6降低了70%。
更新日志
2025年01月11日
2025年01月11日
- 小骆驼-《草原狼2(蓝光CD)》[原抓WAV+CUE]
- 群星《欢迎来到我身边 电影原声专辑》[320K/MP3][105.02MB]
- 群星《欢迎来到我身边 电影原声专辑》[FLAC/分轨][480.9MB]
- 雷婷《梦里蓝天HQⅡ》 2023头版限量编号低速原抓[WAV+CUE][463M]
- 群星《2024好听新歌42》AI调整音效【WAV分轨】
- 王思雨-《思念陪着鸿雁飞》WAV
- 王思雨《喜马拉雅HQ》头版限量编号[WAV+CUE]
- 李健《无时无刻》[WAV+CUE][590M]
- 陈奕迅《酝酿》[WAV分轨][502M]
- 卓依婷《化蝶》2CD[WAV+CUE][1.1G]
- 群星《吉他王(黑胶CD)》[WAV+CUE]
- 齐秦《穿乐(穿越)》[WAV+CUE]
- 发烧珍品《数位CD音响测试-动向效果(九)》【WAV+CUE】
- 邝美云《邝美云精装歌集》[DSF][1.6G]
- 吕方《爱一回伤一回》[WAV+CUE][454M]