fix: Fix when TRT prunes away an output · pytorch/TensorRT@9465e1d · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Commit 9465e1d

Browse files
committed
fix: Fix when TRT prunes away an output
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent 4d2cb14 commit 9465e1d

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

core/conversion/conversion.cpp

+26-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
<<<<<<< HEAD
12
#include <sstream>
23

4+
=======
5+
>>>>>>> 367dd7bb... chore: refactor applyIdentityOp
36
#include "core/conversion/conversion.h"
7+
#include <torch/torch.h>
8+
#include <sstream>
49
#include "core/conversion/conversionctx/ConversionCtx.h"
510
#include "core/conversion/converters/converters.h"
611
#include "core/conversion/evaluators/evaluators.h"
@@ -234,10 +239,28 @@ void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outp
234239
}
235240
}
236241
} else {
237-
std::string name = std::string("output_") + std::to_string(ctx->num_outputs);
242+
bool setOutput = false;
243+
auto num_inputs = ctx->net->getNbInputs();
238244
auto out_tensor = it->second;
239-
out_tensor->setName(name.c_str());
240-
ctx->net->markOutput(*out_tensor);
245+
std::string name = std::string("output_") + std::to_string(ctx->num_outputs);
246+
247+
// Check if the output tensor is one of the inputs to the network. If so, apply an identity layer to it.
248+
for (int64_t i = 0; i < num_inputs; i++) {
249+
if (out_tensor == ctx->net->getInput(i)) {
250+
LOG_DEBUG(
251+
"One of the inputs named "
252+
<< ctx->net->getInput(i)->getName()
253+
<< " to the network is marked as an output tensor. Applying an identity layer and marking this tensor as output");
254+
auto id_out_tensor = converters::applyIdentityOp(ctx, out_tensor, name);
255+
ctx->net->markOutput(*id_out_tensor);
256+
setOutput = true;
257+
}
258+
}
259+
260+
if (!setOutput) {
261+
out_tensor->setName(name.c_str());
262+
ctx->net->markOutput(*out_tensor);
263+
}
241264
LOG_INFO(
242265
ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)");
243266
ctx->num_outputs += 1;

core/conversion/converters/converter_util.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ nvinfer1::ILayer* add_elementwise(
121121
return ele;
122122
}
123123

124+
// Adds an identity layer to ctx->net that takes `tensor` as its input, names the layer's
// first output `tensor_name`, and returns that output tensor. The input `tensor` itself is
// not renamed by this function.
// NOTE(review): the pointer returned by addIdentity is dereferenced without a null check —
// confirm that layer creation cannot fail for the tensors passed here.
nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor, const std::string& tensor_name) {
125+
auto id_layer = ctx->net->addIdentity(*tensor);
126+
auto id_out_tensor = id_layer->getOutput(0);
127+
id_out_tensor->setName(tensor_name.c_str());
128+
return id_out_tensor;
129+
}
130+
124131
nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype) {
125132
if (tensor->getType() != dtype) {
126133
std::ostringstream tensor_id;

core/conversion/converters/converter_util.h

+3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ nvinfer1::ILayer* add_elementwise(
4141
nvinfer1::ITensor* other,
4242
const std::string& name);
4343

44+
// Apply an identity operation on a tensor. Used in the case where an input is an output to a network.
45+
nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor, const std::string& name);
46+
4447
// If an ITensor is of a type not dtype, add an Identity layer to cast it to dtype
4548
nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype);
4649

0 commit comments

Comments
 (0)
0