Add C++ bindings for cuDNN by colesbury · Pull Request #167 · pytorch/pytorch · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add C++ bindings for cuDNN #167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit on
Oct 26, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
Expand Down
18 changes: 18 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

# TODO: make this more robust
WITH_CUDA = os.path.exists('/Developer/NVIDIA/CUDA-7.5/include') or os.path.exists('/usr/local/cuda/include')
WITH_CUDNN = WITH_CUDA
DEBUG = False

################################################################################
Expand Down Expand Up @@ -81,10 +82,15 @@ def run(self):
from tools.cwrap.plugins.AutoGPU import AutoGPU
from tools.cwrap.plugins.BoolOption import BoolOption
from tools.cwrap.plugins.KwargsPlugin import KwargsPlugin
from tools.cwrap.plugins.NullableArguments import NullableArguments
from tools.cwrap.plugins.CuDNNPlugin import CuDNNPlugin
cwrap('torch/csrc/generic/TensorMethods.cwrap', plugins=[
AutoGPU(condition='IS_CUDA'), THPLongArgsPlugin(), BoolOption(),
THPPlugin(), ArgcountSortPlugin(), KwargsPlugin(),
])
cwrap('torch/csrc/cudnn/cuDNN.cwrap', plugins=[
CuDNNPlugin(), NullableArguments()
])
# It's an old-style class in Python 2.7...
setuptools.command.build_ext.build_ext.run(self)

Expand Down Expand Up @@ -192,6 +198,18 @@ def run(self):
"torch/csrc/cuda/serialization.cpp",
]

if WITH_CUDNN:
main_libraries += ['cudnn']
main_sources += [
"torch/csrc/cudnn/Module.cpp",
"torch/csrc/cudnn/Conv.cpp",
"torch/csrc/cudnn/cuDNN.cpp",
"torch/csrc/cudnn/Types.cpp",
"torch/csrc/cudnn/Handles.cpp",
"torch/csrc/cudnn/CppWrapper.cpp",
]
extra_compile_args += ['-DWITH_CUDNN']

if DEBUG:
extra_compile_args += ['-O0', '-g']
extra_link_args += ['-O0', '-g']
Expand Down
15 changes: 13 additions & 2 deletions tools/cwrap/cwrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,13 @@ def get_wrapper_template(self, declaration):
return self.search_plugins('get_wrapper_template', (declaration,), lambda _: None)

def get_arg_accessor(self, arg, option):
return self.search_plugins('get_arg_accessor', (arg, option), lambda arg,_: 'PyTuple_GET_ITEM(args, {})'.format(arg['idx']))
def wrap_accessor(arg, _):
if arg.get('idx') is None:

This comment was marked as off-topic.

This comment was marked as off-topic.

raise RuntimeError("Missing accessor for '{} {}'".format(
arg['type'], arg['name']))
return 'PyTuple_GET_ITEM(args, {})'.format(arg['idx'])

return self.search_plugins('get_arg_accessor', (arg, option), wrap_accessor)

def generate_wrapper(self, declaration):
wrapper = ''
Expand All @@ -153,7 +159,12 @@ def map_selected_arguments(self, base_fn_name, plugin_fn_name, option, arguments
result = []
for arg in arguments:
accessor = self.get_arg_accessor(arg, option)
res = getattr(self, base_fn_name)(arg, option).substitute(arg=accessor)
tmpl = getattr(self, base_fn_name)(arg, option)
if tmpl is None:
fn = 'check' if base_fn_name == 'get_type_check' else 'unpack'
raise RuntimeError("Missing type {} for '{} {}'".format(
fn, arg['type'], arg['name']))
res = tmpl.substitute(arg=accessor)
for plugin in self.plugins:
res = getattr(plugin, plugin_fn_name)(res, arg, accessor)
result.append(res)
Expand Down
159 changes: 159 additions & 0 deletions tools/cwrap/plugins/CuDNNPlugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from string import Template
from copy import deepcopy
from . import CWrapPlugin
from itertools import product

class CuDNNPlugin(CWrapPlugin):
    """cwrap plugin that generates the Python bindings for cuDNN wrappers.

    It provides unpack/check templates for the argument types used by the
    cuDNN declarations, renames every declaration to ``THCUDNN_<name>``
    (exposed to Python as ``_<name>`` unless ``python_name`` is given),
    supplies accessors for the implicit ``self``/``dataType``/``handle``
    arguments, prepends ``state`` to every call, and appends a
    ``THCUDNN_methods()`` function returning the generated method table.
    """

    # C expressions turning a PyObject* accessor ($arg) into the C value passed
    # to the wrapped function, keyed by the declared C type.
    TYPE_UNPACK = {
        'THTensor*': Template('((THPVoidTensor*)$arg)->cdata'),
        'int': Template('THPUtils_unpackLong($arg)'),
        'cudnnDataType_t': Template('$arg'),
        'cudnnHandle_t': Template('$arg'),
        'Convolution*': Template('(Convolution*)THPWrapper_get($arg)'),
        'bool': Template('$arg == Py_True'),
    }

    # C boolean expressions testing whether a PyObject* ($arg) has the given
    # C type; types only used for implicit arguments need no entry here.
    TYPE_CHECK = {
        'Convolution*': Template('THPWrapper_check($arg)'),
        'THTensor*': Template('(PyObject*)Py_TYPE($arg) == tensorClass'),
        'int': Template('THPUtils_checkLong($arg)'),
        'bool': Template('PyBool_Check($arg)'),
    }

    # Statements that box the C return value ($result) into a PyObject*.
    RETURN_WRAPPER = {
        'Convolution*': Template('return THPWrapper_New($result, [](void* arg) { delete (Convolution*)arg; });'),
        'THTensor*': Template('return THPTensor_(New)($result);'),
    }

    # Method table plus accessor, appended once by process_full_file().
    METHODS_DECLARATION = Template("""
static PyMethodDef _THCUDNN_methods[] = {
$methods
{NULL}
};

PyMethodDef* THCUDNN_methods()
{
return _THCUDNN_methods;
}
""")

    # Skeleton of each generated wrapper. $readable_name, $num_options and
    # $expected_args are filled by get_wrapper_template(); $name and $options
    # are left for cwrap. NOTE(review): the `}` after $options presumably
    # closes the option if-chain that cwrap emits -- confirm in cwrap.py.
    WRAPPER_TEMPLATE = Template("""\
static PyObject * $name(PyObject *self, PyObject *args, PyObject *kwargs)
{
HANDLE_TH_ERRORS
int __tuplecount = args ? PyTuple_Size(args) : 0;
int __dictcount = kwargs ? PyDict_Size(kwargs) : 0;
int __argcount = __tuplecount + __dictcount;
PyObject* tensorClass = getTensorClass(args);
THCPAutoGPU __autogpu_guard = THCPAutoGPU(args);

$options
}

THPUtils_invalidArguments(args, "$readable_name", $num_options, $expected_args);
return NULL;
END_HANDLE_TH_ERRORS
}
""")

    # NOTE(review): not referenced inside this class; kept for external users.
    RELEASE_ARG = Template("_${name}_guard.release();")

    # Human-readable type names used in the "invalid arguments" message;
    # types not listed are shown by their C name.
    TYPE_NAMES = {
        'THTensor*': '" THPTensorStr "',
        'long': 'int',
        'bool': 'bool',
        'int': 'int',
    }

    def __init__(self):
        # Every processed declaration, kept so declare_methods() can emit the
        # method table after all wrappers have been generated.
        self.declarations = []

    def get_type_unpack(self, arg, option):
        """Return the unpack Template for ``arg``'s C type, or None."""
        return self.TYPE_UNPACK.get(arg['type'], None)

    def get_type_check(self, arg, option):
        """Return the type-check Template for ``arg``'s C type, or None."""
        return self.TYPE_CHECK.get(arg['type'], None)

    def get_wrapper_template(self, declaration):
        """Build the wrapper Template for ``declaration``.

        Substitutes the Python-visible name and the list of accepted
        signatures (used in the invalid-arguments error) into
        WRAPPER_TEMPLATE. ``safe_substitute`` leaves ``$name`` and
        ``$options`` in place for cwrap to fill in later.
        """
        arg_desc = []
        for option in declaration['options']:
            option_desc = [self.TYPE_NAMES.get(arg['type'], arg['type']) + ' ' + arg['name']
                           for arg in option['arguments']
                           if not arg.get('ignore_check', False)]
            # TODO: this should probably go to THPLongArgsPlugin
            if option_desc:
                arg_desc.append('({})'.format(', '.join(option_desc)))
            else:
                arg_desc.append('no arguments')
        # Shorter signatures are listed first in the error message.
        arg_desc.sort(key=len)
        arg_desc = ['"' + desc + '"' for desc in arg_desc]
        arg_str = ', '.join(arg_desc)
        readable_name = declaration['python_name']
        return Template(self.WRAPPER_TEMPLATE.safe_substitute(
            readable_name=readable_name, num_options=len(arg_desc),
            expected_args=arg_str))

    def get_return_wrapper(self, option):
        """Return the boxing Template for the option's return type, or None."""
        return self.RETURN_WRAPPER.get(option['return'], None)

    def get_arg_accessor(self, arg, option):
        """Return the C expression that supplies an implicit argument.

        Only ``self``, ``dataType`` and ``handle`` are handled here; for any
        other argument None is returned so that the lookup presumably falls
        through to cwrap's default positional-tuple accessor.
        """
        name = arg['name']
        if name == 'self':
            return 'self'
        elif name == 'dataType':
            return 'getCudnnDataType(tensorClass)'
        elif name == 'handle':
            return 'getCudnnHandle()'
        return None

    def process_declarations(self, declarations):
        """Rename declarations, flag implicit arguments and dedupe options.

        Each declaration keeps (or is given) ``python_name = '_<name>'`` and
        its C ``name`` becomes ``THCUDNN_<name>``. Arguments whose values are
        produced by this plugin (``self``, ``state``, ``dataType``,
        ``handle``) are marked ``ignore_check`` so they are excluded from the
        option signature and the error message (and presumably from cwrap's
        type checks). Declarations are mutated in place and also returned.
        """
        implicit_args = ('self', 'state', 'dataType', 'handle')
        for declaration in declarations:
            declaration.setdefault('python_name', '_{}'.format(declaration['name']))
            declaration['name'] = 'THCUDNN_{}'.format(declaration['name'])
            self.declarations.append(declaration)
            for option in declaration['options']:
                for arg in option['arguments']:
                    if arg['name'] in implicit_args:
                        arg['ignore_check'] = True
            declaration['options'] = self.filter_unique_options(declaration['options'])
        return declarations

    def filter_unique_options(self, options):
        """Drop options whose checked-argument type signature repeats.

        Arguments flagged ``ignore_check`` are excluded from the signature,
        so options that differ only in implicit arguments collapse to their
        first occurrence. The order of the surviving options is preserved.
        """
        def signature(option):
            return '#'.join(arg['type'] for arg in option['arguments']
                            if not arg.get('ignore_check', False))

        seen_signatures = set()
        unique = []
        for option in options:
            sig = signature(option)
            if sig in seen_signatures:
                continue
            seen_signatures.add(sig)
            unique.append(option)
        return unique

    def preprocessor_guard(self, code, condition):
        """Wrap ``code`` in ``#if condition`` ... ``#endif``.

        A newline is inserted before ``#endif`` when non-empty ``code`` does
        not already end with one. Otherwise the directive would be appended
        to the last line of ``code`` and not recognised by the preprocessor.
        """
        if code and not code.endswith('\n'):
            code += '\n'
        return '#if ' + condition + '\n' + code + '#endif\n'

    def process_wrapper(self, code, declaration):
        """Guard the generated wrapper with ``defined_if`` when present."""
        if 'defined_if' in declaration:
            return self.preprocessor_guard(code, declaration['defined_if'])
        return code

    def process_all_unpacks(self, code, option):
        """Prepend the implicit ``state`` argument to the call's arguments."""
        return 'state, ' + code

    def declare_methods(self):
        """Render the PyMethodDef table for every processed declaration."""
        entry_template = Template(' {"$python_name", (PyCFunction)$name, METH_VARARGS$extra_flags, NULL},\n')
        entries = []
        for declaration in self.declarations:
            extra_flags = ''
            if 'method_flags' in declaration:
                extra_flags = ' | ' + declaration['method_flags']
            # The wrapper signature takes kwargs, so it is registered with
            # METH_KEYWORDS unless the declaration is marked only_register.
            if not declaration.get('only_register'):
                extra_flags += ' | METH_KEYWORDS'
            entry = entry_template.substitute(
                python_name=declaration['python_name'], name=declaration['name'],
                extra_flags=extra_flags
            )
            if 'defined_if' in declaration:
                entry = self.preprocessor_guard(entry, declaration['defined_if'])
            entries.append(entry)
        return self.METHODS_DECLARATION.substitute(methods=''.join(entries))

    def process_full_file(self, code):
        """Append the method table and its accessor to the generated file."""
        return code + self.declare_methods()
3 changes: 3 additions & 0 deletions tools/cwrap/plugins/THPPlugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class THPPlugin(CWrapPlugin):
'THLongStorage*': Template('((THPLongStorage*)$arg)->cdata'),
'THStorage*': Template('((THPStorage*)$arg)->cdata'),
'THGenerator*': Template('((THPGenerator*)$arg)->cdata'),
'THSize*': Template('THPUtils_unpackSize($arg)'),
'void*': Template('THPUtils_unpackLong($arg)'),
'long': Template('THPUtils_unpackLong($arg)'),
'int': Template('THPUtils_unpackLong($arg)'),
Expand All @@ -38,6 +39,7 @@ class THPPlugin(CWrapPlugin):
'THLongStorage*': Template('(PyObject*)Py_TYPE($arg) == THPLongStorageClass'),
'THStorage*': Template('(PyObject*)Py_TYPE($arg) == THPStorageClass'),
'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
'THSize*': Template('(PyObject*)Py_TYPE($arg) == THPSizeClass'),
'void*': Template('THPUtils_checkLong($arg)'),
'long': Template('THPUtils_checkLong($arg)'),
'int': Template('THPUtils_checkLong($arg)'),
Expand Down Expand Up @@ -152,6 +154,7 @@ class THPPlugin(CWrapPlugin):
'THIndexTensor*': '" THPModuleStr "LongTensor',
'THFloatTensor*': '" THPModuleStr "FloatTensor',
'THDoubleTensor*': '" THPModuleStr "DoubleTensor',
'THSize*': 'torch.Size',
'long': 'int',
'real': '" RealStr "',
'double': 'float',
Expand Down
1 change: 1 addition & 0 deletions tools/cwrap/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@ def process_option_code_template(self, template, option):
from .ReturnArguments import ReturnArguments
from .GILRelease import GILRelease
from .AutoGPU import AutoGPU
from .CuDNNPlugin import CuDNNPlugin
Loading
0