rocAL Reinterpret cast by fiona-gladwin · Pull Request #7 · fiona-gladwin/rocAL · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

rocAL Reinterpret cast #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
8000
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions rocAL/include/api/rocal_api_augmentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1144,4 +1144,10 @@ extern "C" RocalTensor ROCAL_API_CALL rocalSpectrogram(RocalContext context,
RocalTensorLayout output_layout = ROCAL_NFT,
RocalTensorOutputType output_datatype = ROCAL_FP32);

/*! \brief Reinterprets the data of a tensor as a different data type.
 * \details Creates a new tensor whose description is a copy of p_input's tensor info with
 *          its element type reset to output_datatype, and adds a copy node that fills it from
 *          p_input. The innermost dimension is rescaled for the new element size; if that is
 *          not possible the error is captured on the context and reported.
 * \param [in] p_context Rocal context
 * \param [in] p_input Input Rocal tensor whose data is reinterpreted
 * \param [in] output_datatype Data type the input data is reinterpreted as
 * \param [in] is_output Forwarded to tensor creation as its is_output flag
 *                       (presumably marks the tensor as a pipeline output — confirm); defaults to false
 * \return The reinterpreted tensor, or nullptr if p_context or p_input is null or the output
 *         tensor could not be created
 */
extern "C" RocalTensor ROCAL_API_CALL rocalReinterpretCast(
RocalContext p_context,
RocalTensor p_input,
RocalTensorOutputType output_datatype,
bool is_output = false);

#endif // MIVISIONX_ROCAL_API_AUGMENTATION_H
13 changes: 13 additions & 0 deletions rocAL/include/pipeline/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,19 @@ class TensorInfo {
_data_size = (_data_size / _data_type_size);
_data_size *= data_type_size();
}
void reset_data_type(RocalTensorDataType data_type) {
    // Reinterprets the tensor's bytes as `data_type` (no value conversion).
    // The byte length of the innermost dimension is preserved and only its extent
    // is rescaled: new_extent = old_extent * old_element_size / new_element_size.
    // Throws (leaving this TensorInfo unchanged) if the innermost byte length is
    // not a multiple of the new element size.
    if (_data_type == data_type)
        return;
    if (_num_of_dims == 0) {
        THROW("Cannot reinterpret the data type of a tensor with no dimensions")
    }
    // Element size of the current type; must be read before _data_type is replaced.
    const size_t old_type_size = data_type_size();
    const RocalTensorDataType old_data_type = _data_type;
    _data_type = data_type;
    const size_t new_type_size = data_type_size();
    const size_t innermost_bytes = _dims[_num_of_dims - 1] * old_type_size;
    if (innermost_bytes % new_type_size != 0) {
        // Roll back so a failed reinterpretation leaves the object consistent;
        // re-query the size in case data_type_size() refreshes a cached value.
        _data_type = old_data_type;
        static_cast<void>(data_type_size());
        THROW("The innermost dimension is not divisible by the requested data type, Data type change cannot be done")
    }
    _dims[_num_of_dims - 1] = innermost_bytes / new_type_size;
    // _data_size is a byte count (element count * element size); reinterpreting
    // the same bytes does not change it, so it is intentionally left untouched.
    modify_strides();
    set_max_shape();
}
void get_modified_dims_from_layout(RocalTensorlayout input_layout, RocalTensorlayout output_layout, std::vector<size_t>& new_dims) {
std::vector<size_t> dims_mapping;
if (input_layout == RocalTensorlayout::NHWC && output_layout == RocalTensorlayout::NCHW) {
Expand Down
26 changes: 26 additions & 0 deletions rocAL/source/api/rocal_api_augmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2156,6 +2156,32 @@ rocalNop(
return output;
}

RocalTensor ROCAL_API_CALL
rocalReinterpretCast(
    RocalContext p_context,
    RocalTensor p_input,
    RocalTensorOutputType output_datatype,
    bool is_output) {
    // Result handle; stays null unless the output tensor gets created below.
    Tensor* result = nullptr;

    // Both handles must be valid before either is dereferenced.
    if (!p_context || !p_input) {
        ERR("Invalid ROCAL context or invalid input tensor")
        return result;
    }

    auto* ctx = static_cast<Context*>(p_context);
    auto* src = static_cast<Tensor*>(p_input);
    try {
        // Start from the input's description and swap only the element type.
        auto result_info = src->info();
        result_info.reset_data_type(static_cast<RocalTensorDataType>(output_datatype));
        result = ctx->master_graph->create_tensor(result_info, is_output);
        // Fill the retyped tensor by copying the input's data.
        ctx->master_graph->add_node<CopyNode>({src}, {result});
    } catch (const std::exception& e) {
        ctx->capture_error(e.what());
        ERR(e.what())
    }
    return result;
}

RocalTensor ROCAL_API_CALL
rocalPreEmphasisFilter(RocalContext p_context,
RocalTensor p_input,
Expand Down
8 changes: 8 additions & 0 deletions rocAL_pybind/amd/rocal/fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,3 +1112,11 @@ def spectrogram(*inputs, bytes_per_sample_hint = [0], center_windows = True, lay
"power": power, "nfft": nfft, "window_length": window_length, "window_step": window_step, "output_layout": layout, "output_dtype": output_dtype}
spectrogram_output = b.spectrogram(Pipeline._current_pipeline._handle, *(kwargs_pybind.values()))
return (spectrogram_output)

def reinterpret(*inputs, output_dtype):
    """
    Reinterprets the data of the first input tensor as another data type.

    :param inputs: pipeline tensors; only ``inputs[0]`` is forwarded to the backend
    :param output_dtype: data type to reinterpret the input as (sent as ``output_datatype``)
    :return: the tensor returned by the backend ``reinterpret`` call, created with ``is_output=False``
    """
    source_tensor = inputs[0]
    backend_args = {
        "input": source_tensor,
        "output_datatype": output_dtype,
        "is_output": False,
    }
    pipeline_handle = Pipeline._current_pipeline._handle
    return b.reinterpret(pipeline_handle, *backend_args.values())
2 changes: 2 additions & 0 deletions rocAL_pybind/rocal_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -731,5 +731,7 @@ PYBIND11_MODULE(rocal_pybind, m) {
py::return_value_policy::reference);
m.def("spectrogram", &rocalSpectrogram,
py::return_value_policy::reference);
m.def("reinterpret", &rocalReinterpretCast,
py::return_value_policy::reference);
}
} // namespace rocal
0