terminate called after throwing an instance of 'c10::Error' what(): isTuple() INTERNAL ASSERT FAILED at "/home/wenda/libtorch/include/ATen/core/ivalue_inl.h":927, please report a bug to PyTorch. Expected Tuple but got GenericList · Issue #53895 · pytorch/pytorch · GitHub

Open
vedics opened this issue Mar 12, 2021 · 14 comments
Labels
module: cpp (Related to C++ API) · module: crash (Problem manifests as a hard crash, as opposed to a RuntimeError) · needs reproduction (Someone else needs to try reproducing the issue given the instructions. No action needed from user) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@vedics
vedics commented Mar 12, 2021
terminate called after throwing an instance of 'c10::Error'
  what():  isTuple() INTERNAL ASSERT FAILED at "/home/wenda/libtorch/include/ATen/core/ivalue_inl.h":927, please report a bug to PyTorch. Expected Tuple but got GenericList
Exception raised from toTuple at /home/wenda/libtorch/include/ATen/core/ivalue_inl.h:927 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x69 (0x7f38baf4cb89 in /home/wenda/libtorch/lib/libc10.so)
frame #1: main + 0xd41 (0x5595dbce8d11 in ./YOLOv5LibTorch)
frame #2: __libc_start_main + 0xe7 (0x7f386bf36bf7 in /lib/x86_64-linux-gnu/libc.so.6)
frame #3: _start + 0x2a (0x5595dbcea11a in ./YOLOv5LibTorch)

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @yf225 @glaringlee

@vedics
Author
vedics commented Mar 12, 2021

This error occurred when I deployed YOLOv5 in C++ with LibTorch. Can anyone help me?

@ngimel
Collaborator
ngimel commented Mar 13, 2021

Can you provide self-contained code reproducing the error?

@vedics
Author
vedics commented Mar 13, 2021

Can you provide self-contained code reproducing the error?

#include <opencv2/opencv.hpp>
#include <torch/script.h>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <string>
#include <time.h>
#include <vector>


std::vector<torch::Tensor> non_max_suppression(torch::Tensor preds, float score_thresh=0.5, float iou_thresh=0.5)
{
        
        
        std::vector<torch::Tensor> output;
        
        for (size_t i=0; i < preds.sizes()[0]; ++i)
        {
            torch::Tensor pred = preds.select(0, i);
            
            // Filter by scores
            torch::Tensor scores = pred.select(1, 4) * std::get<0>( torch::max(pred.slice(1, 5, pred.sizes()[1]), 1));
           
            pred = torch::index_select(pred, 0, torch::nonzero(scores > score_thresh).select(1, 0));
            if (pred.sizes()[0] == 0) continue;

            // (center_x, center_y, w, h) to (left, top, right, bottom)
            pred.select(1, 0) = pred.select(1, 0) - pred.select(1, 2) / 2;
            pred.select(1, 1) = pred.select(1, 1) - pred.select(1, 3) / 2;
            pred.select(1, 2) = pred.select(1, 0) + pred.select(1, 2);
            pred.select(1, 3) = pred.select(1, 1) + pred.select(1, 3);

            // Computing scores and classes
            std::tuple<torch::Tensor, torch::Tensor> max_tuple = torch::max(pred.slice(1, 5, pred.sizes()[1]), 1);
            pred.select(1, 4) = pred.select(1, 4) * std::get<0>(max_tuple);
            pred.select(1, 5) = std::get<1>(max_tuple);

            torch::Tensor  dets = pred.slice(1, 0, 6);

            torch::Tensor keep = torch::empty({dets.sizes()[0]});
            torch::Tensor areas = (dets.select(1, 3) - dets.select(1, 1)) * (dets.select(1, 2) - dets.select(1, 0));
            std::tuple<torch::Tensor, torch::Tensor> indexes_tuple = torch::sort(dets.select(1, 4), 0, 1);
            torch::Tensor v = std::get<0>(indexes_tuple);
            torch::Tensor indexes = std::get<1>(indexes_tuple);
            int count = 0;
            while (indexes.sizes()[0] > 0)
            {
                keep[count] = (indexes[0].item().toInt());
                count += 1;

                // Computing overlaps
                torch::Tensor lefts = torch::empty(indexes.sizes()[0] - 1);
                torch::Tensor tops = torch::empty(indexes.sizes()[0] - 1);
                torch::Tensor rights = torch::empty(indexes.sizes()[0] - 1);
                torch::Tensor bottoms = torch::empty(indexes.sizes()[0] - 1);
                torch::Tensor widths = torch::empty(indexes.sizes()[0] - 1);
                torch::Tensor heights = torch::empty(indexes.sizes()[0] - 1);
                for (size_t i=0; i<indexes.sizes()[0] - 1; ++i)
                {
                    lefts[i] = std::max(dets[indexes[0]][0].item().toFloat(), dets[indexes[i + 1]][0].item().toFloat());
                    tops[i] = std::max(dets[indexes[0]][1].item().toFloat(), dets[indexes[i + 1]][1].item().toFloat());
                    rights[i] = std::min(dets[indexes[0]][2].item().toFloat(), dets[indexes[i + 1]][2].item().toFloat());
                    bottoms[i] = std::min(dets[indexes[0]][3].item().toFloat(), dets[indexes[i + 1]][3].item().toFloat());
                    widths[i] = std::max(float(0), rights[i].item().toFloat() - lefts[i].item().toFloat());
                    heights[i] = std::max(float(0), bottoms[i].item().toFloat() - tops[i].item().toFloat());
                }
                torch::Tensor overlaps = widths * heights;
                
                // Filter by IoUs
                torch::Tensor ious = overlaps / (areas.select(0, indexes[0].item().toInt()) + torch::index_select(areas, 0, indexes.slice(0, 1, indexes.sizes()[0])) - overlaps);
               
                indexes = torch::index_select(indexes, 0, torch::nonzero(ious <= iou_thresh).select(1, 0) + 1);
            }
            keep = keep.toType(torch::kInt64);
            output.push_back(torch::index_select(dets, 0, keep.slice(0, 0, count)));
        }
       
        return output;
}


int main()
{
    // Loading module
    torch::DeviceType device_type = at::kCUDA;
    
    torch::jit::script::Module module = torch::jit::load("../last.torchscript.pt");
    module.to(device_type);
    
    std::vector<std::string> classnames;
    std::ifstream f("../new_coco.names");
    std::string name = "";
    while (std::getline(f, name))
    {
        classnames.push_back(name);
    }

    cv::VideoCapture cap(0);
    cap.set(cv::CAP_PROP_FRAME_WIDTH, 1920);
    cap.set(cv::CAP_PROP_FRAME_HEIGHT, 1080);
    cv::Mat frame, img;
    while(cap.isOpened())
    {
        clock_t start = clock();
        cap.read(frame);
        if(frame.empty())
        {
           std::cout << "Read frame failed!" << std::endl;
           break;
        }
        
        // Preparing input tensor
        cv::resize(frame, img, cv::Size(640, 640));
        //cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
        torch::Tensor imgTensor = torch::from_blob(img.data, {img.rows, img.cols,3},torch::kByte);
        
        imgTensor = imgTensor.permute({2,0,1});
        imgTensor = imgTensor.toType(torch::kFloat);
        imgTensor = imgTensor.div(255);
        imgTensor = imgTensor.unsqueeze(0);
        imgTensor = imgTensor.to(device_type);
        // preds: [?, 15120, 9]
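        // toTuple() asserts isTuple(): this is the call that aborts with
        // "Expected Tuple but got GenericList" when forward() returns a list.
        torch::Tensor preds = module.forward({imgTensor}).toTuple()->elements()[0].toTensor();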
        std::vector<torch::Tensor> dets = non_max_suppression(preds, 0.4, 0.5);
        if (dets.size() > 0)
        {
            // Visualize result
            for (size_t i=0; i < dets[0].sizes()[0]; ++ i)
            {
                float left = dets[0][i][0].item().toFloat() * frame.cols / 640;
                float top = dets[0][i][1].item().toFloat() * frame.rows / 324;
                float right = dets[0][i][2].item().toFloat() * frame.cols / 640;
                float bottom = dets[0][i][3].item().toFloat() * frame.rows / 324;
                float score = dets[0][i][4].item().toFloat();
                int classID = dets[0][i][5].item().toInt();

                cv::rectangle(frame, cv::Rect(left, top, (right - left), (bottom - top)), cv::Scalar(0, 255, 0), 2);

                cv::putText(frame,
                    classnames[classID] + ": " + cv::format("%.2f", score),
                    cv::Point(left, top),
                    cv::FONT_HERSHEY_SIMPLEX, (right - left) / 200, cv::Scalar(0, 255, 0), 2);
            }
        }
        cv::putText(frame, "FPS: " + std::to_string(int(1e7 / (clock() - start))),
            cv::Point(50, 50),
            cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(0, 255, 0), 2);
        cv::imshow("", frame);
        cv::waitKey(67);
        //if(cv::waitKey(67)== 27) break;
    }
    return 0;
}

This is the C++ code I tested.
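The `non_max_suppression` function above converts (center_x, center_y, w, h) boxes to corners and then runs greedy, class-agnostic NMS one tensor element at a time. The same logic can be sketched in plain C++ without LibTorch, which may make it easier to test the post-processing separately from the model loading that crashes. This is a minimal illustration of the technique, not the poster's code; `Box`, `from_center`, `iou`, and `greedy_nms` are hypothetical names.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Axis-aligned box in (left, top, right, bottom) form.
struct Box {
    float left, top, right, bottom;
};

// Hypothetical helper: convert (center_x, center_y, w, h) to corner form.
Box from_center(float cx, float cy, float w, float h) {
    return {cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2};
}

// Intersection-over-union of two corner-form boxes (0 if the union is empty).
float iou(const Box& a, const Box& b) {
    float w = std::max(0.0f, std::min(a.right, b.right) - std::max(a.left, b.left));
    float h = std::max(0.0f, std::min(a.bottom, b.bottom) - std::max(a.top, b.top));
    float inter = w * h;
    float area_a = (a.right - a.left) * (a.bottom - a.top);
    float area_b = (b.right - b.left) * (b.bottom - b.top);
    float uni = area_a + area_b - inter;
    return uni > 0 ? inter / uni : 0.0f;
}

// Greedy NMS: visit boxes in descending score order and keep a box only if
// its IoU with every already-kept box is <= iou_thresh. Returns kept indices.
std::vector<std::size_t> greedy_nms(const std::vector<Box>& boxes,
                                    const std::vector<float>& scores,
                                    float iou_thresh) {
    std::vector<std::size_t> order(boxes.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(),
              [&](std::size_t i, std::size_t j) { return scores[i] > scores[j]; });

    std::vector<std::size_t> keep;
    for (std::size_t idx : order) {
        bool suppressed = false;
        for (std::size_t k : keep) {
            if (iou(boxes[idx], boxes[k]) > iou_thresh) {
                suppressed = true;
                break;
            }
        }
        if (!suppressed) keep.push_back(idx);
    }
    return keep;
}
```

Suppressing a box against every kept box gives the same result as the loop above, which repeatedly takes the top-scoring index and drops the remaining indices whose IoU with it exceeds `iou_thresh`.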

@vedics
Author
vedics commented Mar 13, 2021

Can you provide self-contained code reproducing the error?

inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() && { AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind()); return moveToIntrusivePtr<ivalue::Tuple>(); }

This is the code where the assertion fails.

@ngimel
Collaborator
ngimel commented Mar 15, 2021

Sorry, this is not a reproducible sample: it requires files (e.g. last.torchscript.pt) that we don't have access to. Can you narrow the problem down to the failing call so that it can be reproduced with a shorter script that doesn't require extra files?

@ngimel ngimel added the needs reproduction Someone else needs to try reproducing the issue given the instructions. No action needed from user label Mar 15, 2021
@jbschlosser jbschlosser added high priority module: cpp Related to C++ API module: crash Problem manifests as a hard crash, as opposed to a RuntimeError labels Mar 15, 2021
@glaringlee
Contributor

Are u using the same version of pytorch/libtorch to save/loading .pt file?

@jbschlosser jbschlosser added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 15, 2021
@JXFOnestep

I ran into the same problem. Have you fixed it?

@gigadeplex

Are u using the same version of pytorch/libtorch to save/loading .pt file?

I got "Expected Tuple but got String" and this fixed it for me haha. Apparently I used 1.12.0 (source) with 1.11.0. Thx

@xiaomafei

I ran into the same problem. Have you fixed it?

@mhbassel
mhbassel commented Sep 1, 2022

Are u using the same version of pytorch/libtorch to save/loading .pt file?

That solved my issue, too.
I was scripting and saving the model with a different version from the one the Triton backend used to load it.
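If both version strings are at hand (for example `torch.__version__` recorded at export time and the version of the runtime that loads the model), a quick check before loading can flag a mismatch such as 1.12.0 vs 1.11.0. A minimal sketch; `release_part` and `same_release` are hypothetical names, and treating a local build suffix such as `+cu126` as irrelevant is an assumption.

```cpp
#include <string>

// Hypothetical helper: drop a local build suffix such as "+cu126" and keep
// the release part of a PyTorch version string.
std::string release_part(const std::string& version) {
    // find() returns npos when there is no '+', and substr(0, npos) keeps
    // the whole string.
    return version.substr(0, version.find('+'));
}

// True if the exporting and loading versions name the same release.
bool same_release(const std::string& saved_with, const std::string& loaded_with) {
    return release_part(saved_with) == release_part(loaded_with);
}
```

The comparison is deliberately strict: pre-release or source-build strings (e.g. with an `a0` tag) will not match a plain release string.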

@lmw0320
lmw0320 commented Sep 27, 2022

Are u using the same version of pytorch/libtorch to save/loading .pt file?

I got "Expected Tuple but got String" and this fixed it for me haha. Apparently I used 1.12.0 (source) with 1.11.0. Thx

Hi, do you mean this error can be fixed by updating the PyTorch version?

@Hideman85

I'm having the same trouble. If you want a reproduction: I'm trying to run Stable Diffusion on CPU (there is no GPU on this laptop).

System: Ubuntu 23.04
Libs: python3-torch/lunar,now 1.13.1+dfsg-3 amd64 [installed], libtorch1.13/lunar,now 1.13.1+dfsg-3 amd64 [installed,automatic]
Repo: https://github.com/Stability-AI/stablediffusion
Cmd: python3 scripts/txt2img.py --prompt "A colorful nebula" --ckpt ./checkpoints/v2-1_768.ckpt --config configs/stable-diffusion/v2-inference-v.yaml --H 768 --W 768

Note: I do not use conda; everything comes from system libraries or .local/lib/python3.11/site-packages.

Stack trace
terminate called after throwing an instance of 'c10::Error'
  what():  Type c10::intrusive_ptr<ConvPackedParamsBase<2>, c10::detail::intrusive_target_default_null_type<ConvPackedParamsBase<2> > > could not be converted to any of the known types.
Exception raised from operator() at ./aten/src/ATen/core/jit_type.h:1751 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xa2 (0x7f376c1028a2 in /lib/x86_64-linux-gnu/libc10.so.1.13)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x78 (0x7f376c0ce89a in /lib/x86_64-linux-gnu/libc10.so.1.13)
frame #2: <unknown function> + 0x15792a2 (0x7f37661792a2 in /lib/x86_64-linux-gnu/libtorch_cpu.so.1.13)
frame #3: <unknown function> + 0xf4327a (0x7f3765b4327a in /lib/x86_64-linux-gnu/libtorch_cpu.so.1.13)
frame #4: c10::detail::infer_schema::make_function_schema(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&&, c10::ArrayRef<c10::detail::infer_schema::ArgumentDef>, c10::ArrayRef<c10::detail::infer_schema::ArgumentDef>) + 0x74 (0x7f3765b436b4 in /lib/x86_64-linux-gnu/libtorch_cpu.so.1.13)
frame #5: c10::detail::infer_schema::make_function_schema(c10::ArrayRef<c10::detail::infer_schema::ArgumentDef>, c10::ArrayRef<c10::detail::infer_schema::ArgumentDef>) + 0xaa (0x7f3765b4462a in /lib/x86_64-linux-gnu/libtorch_cpu.so.1.13)
frame #6: <unknown function> + 0x159ddff (0x7f376619ddff in /lib/x86_64-linux-gnu/libtorch_cpu.so.1.13)
frame #7: <unknown function> + 0x1586d6f (0x7f3766186d6f in /lib/x86_64-linux-gnu/libtorch_cpu.so.1.13)
frame #8: <unknown function> + 0x1586fab (0x7f3766186fab in /lib/x86_64-linux-gnu/libtorch_cpu.so.1.13)
frame #9: <unknown function> + 0xcb313d (0x7f37658b313d in /lib/x86_64-linux-gnu/libtorch_cpu.so.1.13)
frame #10: <unknown function> + 0xc59400 (0x7f3765859400 in /lib/x86_64-linux-gnu/libtorch_cpu.so.1.13)
frame #11: <unknown function> + 0x533e (0x7f37aa2d033e in /lib64/ld-linux-x86-64.so.2)
frame #12: <unknown function> + 0x5428 (0x7f37aa2d0428 in /lib64/ld-linux-x86-64.so.2)
frame #13: _dl_catch_exception + 0x102 (0x7f37aa2cc562 in /lib64/ld-linux-x86-64.so.2)
frame #14: <unknown function> + 0xc346 (0x7f37aa2d7346 in /lib64/ld-linux-x86-64.so.2)
frame #15: _dl_catch_exception + 0x7d (0x7f37aa2cc4dd in /lib64/ld-linux-x86-64.so.2)
frame #16: <unknown function> + 0xc6bc (0x7f37aa2d76bc in /lib64/ld-linux-x86-64.so.2)
frame #17: <unknown function> + 0x8abec (0x7f37a9e8abec in /lib/x86_64-linux-gnu/libc.so.6)
frame #18: _dl_catch_exception + 0x7d (0x7f37aa2cc4dd in /lib64/ld-linux-x86-64.so.2)
frame #19: <unknown function> + 0x1603 (0x7f37aa2cc603 in /lib64/ld-linux-x86-64.so.2)
frame #20: <unknown function> + 0x8a6bf (0x7f37a9e8a6bf in /lib/x86_64-linux-gnu/libc.so.6)
frame #21: dlopen + 0x71 (0x7f37a9e8aca1 in /lib/x86_64-linux-gnu/libc.so.6)
frame #22: <unknown function> + 0x16bcd (0x7f37a8ecbbcd in /usr/lib/python3.11/lib-dynload/_ctypes.cpython-311-x86_64-linux-gnu.so)
frame #23: python3() [0x520440]
<omitting python frames>
frame #27: python3() [0x541bf7]
frame #31: python3() [0x56d152]
frame #32: python3() [0x512c27]
frame #35: python3() [0x51e97f]
frame #37: python3() [0x447cc4]
frame #40: python3() [0x56d152]
frame #41: python3() [0x512c27]
frame #44: python3() [0x51e97f]
frame #46: python3() [0x447cc4]
frame #49: python3() [0x56d152]
frame #50: python3() [0x512c27]
frame #53: python3() [0x51e97f]
frame #55: python3() [0x447cc4]
frame #56: python3() [0x57ce3f]
frame #57: python3() [0x512c27]
frame #60: python3() [0x51e97f]
frame #62: python3() [0x447cc4]

Aborted (core dumped)

I'm not an expert in Python, and even less so in PyTorch, so I would appreciate some help here 🙏

@mhbassel

Hi @Hideman85, can you make sure you are using the correct version of PyTorch (1.12.1) and Torchvision (0.13.1) and then give it a try.
I also recommend using a virtual environment (if you aren't already) and installing the repo requirements in it. It doesn't have to be conda; you can create one with Python itself by running something like:

python -m venv PATH_TO_VENV
# Activate
source PATH_TO_VENV/bin/activate

@mckay-w
mckay-w commented Apr 17, 2025

I ran into the same problem. I can save and load the model in Python, but loading fails in C++:
std::exception: Failed to load model. Details: terminate called after throwing an instance of 'c10::Error'
what(): isTuple() INTERNAL ASSERT FAILED at "../aten/src/ATen/core/ivalue_inl.h":931, please report a bug to PyTorch. Expected Tuple but got String
Exception raised from toTuple at ../aten/src/ATen/core/ivalue_inl.h:931 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x69 (0x77b6a4a50b89 in /workspace/home/wch/NH-Rep/code/IsoSurfacing/libtorch/lib/libc10.so)
frame #1: <unknown function> + 0x311e5d4 (0x77b696b1e5d4 in /workspace/home/wch/NH-Rep/code/IsoSurfacing/libtorch/lib/libtorch_cpu.so)
frame #2: <unknown function> + 0x31212e5 (0x77b696b212e5 in /workspace/home/wch/NH-Rep/code/IsoSurfacing/libtorch/lib/libtorch_cpu.so)
frame #3: torch::jit::SourceRange::highlight(std::ostream&) const + 0x3c (0x77b69465782c in /workspace/home/wch/NH-Rep/code/IsoSurfacing/libtorch/lib/libtorch_cpu.so)
frame #4: torch::jit::ErrorReport::what() const + 0x2df (0x77b69463831f in /workspace/home/wch/NH-Rep/code/IsoSurfacing/libtorch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x1a57a (0x5d8f441c657a in /usr/myapp/ISG)
frame #6: <unknown function> + 0x29d90 (0x77b6a4800d90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #7: __libc_start_main + 0x80 (0x77b6a4800e40 in /lib/x86_64-linux-gnu/libc.so.6)
frame #8: <unknown function> + 0x229a5 (0x5d8f441ce9a5 in /usr/myapp/ISG)

env:
python 3.10:
torch 2.6.0
torchvision 0.21.0
libtorch:
2.6.0+cu126
