添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接

1.导出onnx

环境pycharm+onnx1.7.0+torch1.7.1+cuda11.0

使用:python3 tools/export_onnx.py --output-name yolox_s.onnx -n yolox-s -c yolox_s.pth

导出参考官网:https://github.com/Megvii-BaseDetection/YOLOX/tree/main/demo/ONNXRuntime

注意软件版本:TensorRT、cuda、cudnn 各版本之间有匹配要求,onnx 版本也要和 tensorrt、cuda 版本适配,不然很容易失败

2.onnx转trt

环境:VS2019+trt7.1.3.4+cuda11.0

参考源码: https://github.com/shaoeric/TensorRT_ONNX_resnet18

CIFARONNX.cpp 修改如下:

#include <iostream>
#include <opencv2/opencv.hpp>
#include "NvInfer.h"
#include "NvOnnxParser.h"
//#include "NvOnnxParserRuntime.h"
#include "argsParser.h"
#include "logger.h"
#include "common.h"
//#define HEIGHT 640
//#define WIDTH 640
//#define CHANNEL 3
#define BATCH_SIZE 1
//#define NUM_CLASSES 20
#define RUN_FP16 false
#define RUN_INT8 false
using namespace std;
//地址:https://github.com/shaoeric/TensorRT_ONNX_resnet18
class TRTONNX
public:
    TRTONNX(const string& onnx_file, const string& engine_file) : m_onnx_file(onnx_file), m_engine_file(engine_file) {};
    vector<float> prepareImage(const cv::Mat& img);
    bool onnxToTRTModel(nvinfer1::IHostMemory* trt_model_stream);
    bool loadEngineFromFile();
    void doInference(const cv::Mat& img);
private:
    const string m_onnx_file;
    const string m_engine_file;
    samplesCommon::Args gArgs;
    nvinfer1::ICudaEngine* m_engine;
    bool constructNetwork(nvinfer1::IBuilder* builder, nvinfer1::INetworkDefinition* network, nvinfer1::IBuilderConfig* config, nvonnxparser::IParser* parser);
    bool saveEngineFile(nvinfer1::IHostMemory* data);
    std::unique_ptr<char[]> readEngineFile(int& length);
    int64_t volume(const nvinfer1::Dims& d);
    unsigned int getElementSize(nvinfer1::DataType t);
/// Writes a serialized engine to the path stored in m_engine_file.
///
/// @param data Serialized engine memory to write.
/// @return true when all bytes were written; false when `data` is null, the
///         file could not be opened, or the write failed.
bool TRTONNX::saveEngineFile(nvinfer1::IHostMemory* data)
{
    if (data == nullptr)
    {
        cerr << "no serialized engine data to save" << endl;
        return false;
    }

    std::ofstream file;
    file.open(m_engine_file, std::ios::binary | std::ios::out);
    if (!file.is_open())
    {
        cerr << "cannot open engine file for writing: " << m_engine_file << endl;
        return false;
    }

    cout << "writing engine file..." << endl;
    file.write(static_cast<const char*>(data->data()), static_cast<std::streamsize>(data->size()));
    // Close before checking the state so buffered-write errors are seen too.
    file.close();
    if (file.fail())
    {
        cerr << "failed to write engine file: " << m_engine_file << endl;
        return false;
    }

    cout << "save engine file done" << endl;
    return true;
}
/// Parses the ONNX model into `network` and applies the build settings.
///
/// @param builder Builder that receives the max batch size and DLA settings.
/// @param network Network definition the parser populates; also passed to
///                setAllTensorScales when RUN_INT8 is enabled.
/// @param config  Builder configuration that receives workspace size and
///                precision flags.
/// @param parser  ONNX parser bound to `network`.
/// @return false if the ONNX file could not be parsed, true otherwise.
bool TRTONNX::constructNetwork(nvinfer1::IBuilder* builder, nvinfer1::INetworkDefinition* network, nvinfer1::IBuilderConfig* config, nvonnxparser::IParser* parser)
{
    // Parse the ONNX file; log and stop on failure.
    if (!parser->parseFromFile(this->m_onnx_file.c_str(), static_cast<int>(gLogger.getReportableSeverity())))
    {
        gLogError << "Fail to parse ONNX file" << std::endl;
        return false;
    }

    // Settings used when the engine is built.
    builder->setMaxBatchSize(BATCH_SIZE);
    config->setMaxWorkspaceSize(1 << 30);

    if (RUN_FP16)
    {
        config->setFlag(nvinfer1::BuilderFlag::kFP16);
    }

    if (RUN_INT8)
    {
        config->setFlag(nvinfer1::BuilderFlag::kINT8);
        // No calibrator is configured in this listing; the sample helper is
        // given 127 for both scale arguments instead.
        samplesCommon::setAllTensorScales(network, 127.0f, 127.0f);
    }

    samplesCommon::enableDLA(builder, config, gArgs.useDLACore);
    return true;
}
/* 在没有trt engine plan文件的情况下,从onnx文件构建engine,然后序列化成engine plan文件 */
/// Builds a TensorRT engine from the ONNX file and writes the serialized
/// engine to the engine plan file.
///
/// @param trt_model_stream Unused. The pointer is received by value, so an
///        assignment made here could never reach the caller; the parameter is
///        kept only so existing call sites keep compiling.
/// @return true when the engine was built and saved; false when any step
///         (object creation, ONNX parsing, engine build, serialization, file
///         write) failed.
bool TRTONNX::onnxToTRTModel(nvinfer1::IHostMemory* trt_model_stream)
{
    (void)trt_model_stream;

    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(gLogger.getTRTLogger());
    if (builder == nullptr)
    {
        return false;
    }

    // The network is created with the explicit-batch flag set.
    const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(explicitBatch);
    nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
    nvonnxparser::IParser* parser = (network != nullptr) ? nvonnxparser::createParser(*network, gLogger.getTRTLogger()) : nullptr;

    // Releases every build-time object on all exit paths (success or failure).
    const auto release = [&]() {
        if (parser != nullptr)
        {
            parser->destroy();
        }
        if (config != nullptr)
        {
            config->destroy();
        }
        if (network != nullptr)
        {
            network->destroy();
        }
        builder->destroy();
    };

    bool ok = (network != nullptr) && (config != nullptr) && (parser != nullptr);

    // Parse the model and configure the builder.
    if (ok)
    {
        ok = constructNetwork(builder, network, config, parser);
    }

    if (ok)
    {
        m_engine = builder->buildEngineWithConfig(*network, *config);
        ok = (m_engine != nullptr);
    }

    // Serialize once, write the bytes to disk, then release the host buffer.
    if (ok)
    {
        nvinfer1::IHostMemory* data = m_engine->serialize();
        ok = (data != nullptr) && saveEngineFile(data);
        if (data != nullptr)
        {
            data->destroy();
        }
    }

    release();
    // m_engine is not destroyed here (the original listing had that call
    // commented out), so the built engine stays available on the object.
    return ok;
}
int main()
    string onnx_file = "D:/vs_projects/onnx2trt/yolox_s.onnx";
    string engine_file = "D:/vs_projects/onnx2trt/yolox_s.trt";
    IHostMemory* trt_model_stream{ nullptr };
    TRTONNX trt_onnx(onnx_file, engine_file);
    // 打开文件,是否存在engine plan
    trt_onnx.onnxToTRTModel(trt_model_stream);
View Code

 由于只进行转换不推理,所以去掉了推理部分的代码,

    同时将其中的一行代码 

// Original line in the referenced source (the line being replaced):
nvinfer1::INetworkDefinition* network = builder->createNetwork();
// Replacement: build the explicit-batch flag and create the network with createNetworkV2.
    const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(explicitBatch);

否则会报错:

这里参考了tensorrt7.1.3.4里的sampleOnnxMNIST

3.yolox tensorrt推理

   这里参考:YOLOX Window10 TensorRT 全面部署教程中的windows部署部分,根据自己实际情况配置cuda和tensorrt

   自己直接运行报错:

   参考:解锁新姿势-使用TensorRT部署pytorch模型

   在yolox里执行

import onnxruntime as ort

# Load the exported model and print the names of its first input and first
# output tensors.
ort_session = ort.InferenceSession('yolox_s.onnx')
input_name = ort_session.get_inputs()[0].name    # avoids shadowing the builtin `input`
output_name = ort_session.get_outputs()[0].name
print(input_name)
print(output_name)
images
output
// Blob names before the change (presumably the defaults in the YOLOX TensorRT demo source — confirm):
const char* INPUT_BLOB_NAME = "input_0";
const char* OUTPUT_BLOB_NAME = "output_0";
// Blob names after the change; they match the "images" / "output" names printed by the onnxruntime snippet:
const char* INPUT_BLOB_NAME = "images";
const char* OUTPUT_BLOB_NAME = "output";

问题解决。。。