8000 Stream support for BLAS Handles. by colesbury · Pull Request #181 · torch/cutorch · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Stream support for BLAS Handles. #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions FFI.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,30 @@ if ok then
local cdefs = [[
typedef struct CUstream_st *cudaStream_t;

struct cublasContext;
typedef struct cublasContext *cublasHandle_t;
typedef struct CUhandle_st *cublasHandle_t;

typedef struct _THCCudaResourcesPerDevice {
cudaStream_t* streams;
cublasHandle_t* blasHandles;
size_t scratchSpacePerStream;
void** devScratchSpacePerStream;
} THCCudaResourcesPerDevice;


typedef struct THCState
{
struct THCRNGState* rngState;
struct THCBlasState* blasState;
struct cudaDeviceProp* deviceProperties;
cudaStream_t currentStream;
cudaStream_t** streamsPerDevice;
cublasHandle_t currentBlasHandle;
THCCudaResourcesPerDevice* resourcesPerDevice;
int numDevices;
int numUserStreams;
int numUserBlasHandles;
int currentPerDeviceStream;
int currentPerDeviceBlasHandle;
struct THAllocator* cudaHostAllocator;
} THCState;

Expand All @@ -27,6 +41,7 @@ typedef struct THCudaStorage
char flag;
THAllocator *allocator;
void *allocatorContext;
struct THCudaStorage *view;
} THCudaStorage;

typedef struct THCudaTensor
Expand Down
73 changes: 72 additions & 1 deletion init.c
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,23 @@ static int cutorch_reserveStreams(lua_State *L)
return 0;
}

/*
Usage:
cutorch.reserveBlasHandles(n)
Ensures that at least n cublas handles exist on every device present.
If fewer than n are currently allocated, more are created; if n or more
already exist, this is a no-op. Unlike streams, there is no default
blasHandle.
*/
static int cutorch_reserveBlasHandles(lua_State *L)
{
  /* luaL_checknumber raises a Lua error on a non-numeric argument. */
  const int count = (int) luaL_checknumber(L, 1);
  THCState_reserveBlasHandles(cutorch_getstate(L), count);
  return 0; /* no values returned to Lua */
}

/*
Usage:
n = cutorch.getNumStreams()
Expand All @@ -224,6 +241,20 @@ static int cutorch_getNumStreams(lua_State *L)
return 1;
}

/*
Usage:
n = cutorch.getNumBlasHandles()
Pushes onto the Lua stack the number of user blasHandles allocated for
every device present. By default, is 1.
*/
static int cutorch_getNumBlasHandles(lua_State *L)
{
  lua_pushnumber(L, THCState_getNumBlasHandles(cutorch_getstate(L)));
  return 1; /* one value returned to Lua */
}

/*
Usage:
cutorch.setStream(n)
Expand All @@ -247,6 +278,28 @@ static int cutorch_setStream(lua_State *L)
return 0;
}

/*
Usage:
cutorch.setBlasHandle(n)
Selects, for all devices, which blasHandle index is current. e.g.,
---
cutorch.setDevice(1)
cutorch.setBlasHandle(3)
-- device 1 blasHandle 3 in use here
cutorch.setDevice(2)
-- device 2 blasHandle 3 in use here
---
*/
static int cutorch_setBlasHandle(lua_State *L)
{
  /* luaL_checknumber raises a Lua error on a non-numeric argument. */
  const int index = (int) luaL_checknumber(L, 1);
  THCState_setBlasHandleForCurrentDevice(cutorch_getstate(L), index);
  return 0; /* no values returned to Lua */
}

/*
Usage:
n = cutorch.getStream()
Expand All @@ -262,6 +315,20 @@ static int cutorch_getStream(lua_State *L)
return 1;
}

/*
Usage:
n = cutorch.getBlasHandle()
Pushes onto the Lua stack the index of the blasHandle currently in use
for all devices (as previously set via cutorch.setBlasHandle(n)).
*/
static int cutorch_getBlasHandle(lua_State *L)
{
  lua_pushnumber(L, THCState_getCurrentBlasHandleIndex(cutorch_getstate(L)));
  return 1; /* one value returned to Lua */
}

/*
Usage:
cutorch.setDefaultStream()
Expand Down Expand Up @@ -537,10 +604,10 @@ static int cutorch_setDevice(lua_State *L)
int device = (int)luaL_checknumber(L, 1)-1;
THCudaCheck(cudaSetDevice(device));
THCRandom_setGenerator(state, device);
THCudaBlas_setHandle(state, device);

/* The stream is per device, so update the stream as well */
THCState_setStream(state, device, THCState_getCurrentStreamIndex(state));
THCState_setBlasHandle(state, device, THCState_getCurrentBlasHandleIndex(state));

return 0;
}
Expand Down Expand Up @@ -658,6 +725,10 @@ static int cutorch_getState(lua_State *L)

static const struct luaL_Reg cutorch_stuff__ [] = {
{"synchronize", cutorch_synchronize},
{"reserveBlasHandles", cutorch_reserveBlasHandles},
{"getNumBlasHandles", cutorch_getNumBlasHandles},
{"setBlasHandle", cutorch_setBlasHandle},
{"getBlasHandle", cutorch_getBlasHandle},
{"reserveStreams", cutorch_reserveStreams},
{"getNumStreams", cutorch_getNumStreams},
{"setStream", cutorch_setStream},
Expand Down
3 changes: 2 additions & 1 deletion lib/THC/THCApply.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ inline bool getApplyGrid(THCState* state, long totalElements, dim3& grid) {

// 16 warps per block * 4 per SM gives 64 warps per SM at maximum,
// which seems to be a good sweetspot for latency hiding
grid = dim3(min(DIVUP(totalElements, (long long) THC_APPLY_THREADS_PER_BLOCK),
grid = dim3(min((long long) THCCeilDiv(totalElements,
(long) THC_APPLY_THREADS_PER_BLOCK),
4LL * numSM));
return true;
}
Expand Down
56 changes: 9 additions & 47 deletions lib/THC/THCBlas.cu
Original file line number Diff line number Diff line change
@@ -1,44 +1,6 @@
#include "THCBlas.h"
#include "THCGeneral.h"

void THCudaBlas_init(THCState *state, int devices, int device)
{
THCBlasState *blas_state = state->blasState;
blas_state->handles = (cublasHandle_t *)malloc(devices * sizeof(cublasHandle_t));
for (int i = 0; i < devices; i++) {
// Create handle on each device:
cudaSetDevice(i);
cublasCreate(&blas_state->handles[i]);
}

// Set current handle:
blas_state->current_handle = &blas_state->handles[device];
blas_state->n_devices = devices;

// Restore device:
cudaSetDevice(device);
}

void THCudaBlas_shutdown(THCState *state)
{
THCBlasState *blas_state = state->blasState;
for (int i = 0; i < blas_state->n_devices; i++) {
cublasDestroy(blas_state->handles[i]);
}
free(blas_state->handles);
}

void THCudaBlas_setHandle(THCState *state, int device)
{
THCBlasState *blas_state = state->blasState;
blas_state->current_handle = &blas_state->handles[device];
}

void THCudaBlas_setStream(THCState *state, int device, cudaStream_t stream)
{
THCublasCheck(cublasSetStream(state->blasState->handles[device], stream));
}

void THCudaBlas_swap(THCState *state, long n, float *x, long incx, float *y, long incy)
{
if(n == 1)
Expand All @@ -52,7 +14,7 @@ void THCudaBlas_swap(THCState *state, long n, float *x, long incx, float *y, lon
int i_n = (int)n;
int i_incx = (int)incx;
int i_incy = (int)incy;
THCublasCheck(cublasSswap(*state->blasState->current_handle, i_n, x, i_incx, y, i_incy));
THCublasCheck(cublasSswap(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy));
return;
}
THError("Cublas_swap only supports n, incx and"
Expand All @@ -68,7 +30,7 @@ void THCudaBlas_scal(THCState *state, long n, float a, float *x, long incx)
{
int i_n = (int)n;
int i_incx = (int)incx;
THCublasCheck(cublasSscal(*state->blasState->current_handle, i_n, &a, x, i_incx));
THCublasCheck(cublasSscal(THCState_getCurrentBlasHandle(state), i_n, &a, x, i_incx));
return;
}
THError("Cublas_scal only supports n and incx "
Expand All @@ -88,7 +50,7 @@ void THCudaBlas_copy(THCState *state, long n, float *x, long incx, float *y, lon
int i_n = (int)n;
int i_incx = (int)incx;
int i_incy = (int)incy;
THCublasCheck(cublasScopy(*state->blasState->current_handle, i_n, x, i_incx, y, i_incy));
THCublasCheck(cublasScopy(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy));
return;
}

Expand All @@ -109,7 +71,7 @@ void THCudaBlas_axpy(THCState *state, long n, float a, float *x, long incx, floa
int i_n = (int)n;
int i_incx = (int)incx;
int i_incy = (int)incy;
THCublasCheck(cublasSaxpy(*state->blasState->current_handle, i_n, &a, x, i_incx, y, i_incy));
THCublasCheck(cublasSaxpy(THCState_getCurrentBlasHandle(state), i_n, &a, x, i_incx, y, i_incy));
return;
}

Expand All @@ -131,7 +93,7 @@ float THCudaBlas_dot(THCState *state, long n, float *x, long incx, float *y, lon
int i_incx = (int)incx;
int i_incy = (int)incy;
float result;
THCublasCheck(cublasSdot(*state->blasState->current_handle, i_n, x, i_incx, y, i_incy, &result));
THCublasCheck(cublasSdot(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy, &result));
cudaDeviceSynchronize();
return result;
}
Expand Down Expand Up @@ -162,7 +124,7 @@ void THCudaBlas_gemv(THCState *state, char trans, long m, long n, float alpha, f
int i_incx = (int)incx;
int i_incy = (int)incy;

THCublasCheck(cublasSgemv(*state->blasState->current_handle, op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
THCublasCheck(cublasSgemv(THCState_getCurrentBlasHandle(state), op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
return;
}
THError("Cublas_gemv only supports m, n, lda, incx, incy"
Expand All @@ -182,7 +144,7 @@ void THCudaBlas_ger(THCState *state, long m, long n, float alpha, float *x, long
int i_incx = (int)incx;
int i_incy = (int)incy;

THCublasCheck(cublasSger(*state->blasState->current_handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
THCublasCheck(cublasSger(THCState_getCurrentBlasHandle(state), i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
return;
}
THError("Cublas_ger only supports m, n, lda, incx, incy"
Expand Down Expand Up @@ -246,7 +208,7 @@ void THCudaBlas_gemm(THCState *state, char transa, char transb, long m, long n,
int i_ldb = (int)ldb;
int i_ldc = (int)ldc;

THCublasCheck(cublasSgemm(*state->blasState->current_handle, opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc));
THCublasCheck(cublasSgemm(THCState_getCurrentBlasHandle(state), opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc));
return;
}
THError("Cublas_gemm only supports m, n, k, lda, ldb, ldc"
Expand All @@ -267,7 +229,7 @@ void THCudaBlas_gemmBatched(THCState *state, char transa, char transb, long m, l
cublasOperation_t opa = convertTransToCublasOperation(transa);
cublasOperation_t opb = convertTransToCublasOperation(transb);

THCublasCheck(cublasSgemmBatched(*state->blasState->current_handle,
THCublasCheck(cublasSgemmBatched(THCState_getCurrentBlasHandle(state),
opa, opb, (int)m, (int)n, (int)k,
&alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
(int)batchCount));
Expand Down
6 changes: 0 additions & 6 deletions lib/THC/THCBlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@

#include "THCGeneral.h"

typedef struct THCBlasState {
cublasHandle_t* handles;
cublasHandle_t* current_handle;
int n_devices;
} THCBlasState;

/* Level 1 */
THC_API void THCudaBlas_swap(THCState *state, long n, float *x, long incx, float *y, long incy);
THC_API void THCudaBlas_scal(THCState *state, long n, float a, float *x, long incx);
Expand Down
15 changes: 15 additions & 0 deletions lib/THC/THCDeviceUtils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef THC_DEVICE_UTILS_INC
#define THC_DEVICE_UTILS_INC

/* The largest consecutive integer representable in float32 (2^24) */
#define FLOAT32_MAX_CONSECUTIVE_INT 16777216.0f

/**
   Computes ceil(a / b) via integer arithmetic, callable from both host
   and device code.
   NOTE(review): assumes a >= 0 and b > 0 — for negative operands C++
   truncating division makes (a + b - 1) / b differ from a true ceiling;
   also a + b - 1 can overflow when a is near the maximum of T. Callers
   here pass small positive counts — confirm before wider use.
*/
template <typename T>
__host__ __device__ __forceinline__ T THCCeilDiv(T a, T b) {
  return (a + b - 1) / b;
}

#endif // THC_DEVICE_UTILS_INC
Loading
0