Skip to content

Export dynamic batch size ONNX using ONNX's DeformConv#167

Open
itskyf wants to merge 1 commit intoZhengPeng7:mainfrom
itskyf:main
Open

Export dynamic batch size ONNX using ONNX's DeformConv#167
itskyf wants to merge 1 commit intoZhengPeng7:mainfrom
itskyf:main

Conversation

@itskyf
Copy link

@itskyf itskyf commented Jan 20, 2025

This PR replaces the usage of deform_conv2d_onnx_exporter with the native DeformConv operator available in ONNX opset 19. The exported ONNX model now supports dynamic batch sizes.

Notes

@ZhengPeng7
Copy link
Owner

ZhengPeng7 commented Jan 20, 2025

Thanks a lot! I'll take time to look at it tomorrow, which really helps.
BTW, could you update the PR with no output and minimal modification in the notebook? That would be very nice for me to read and test the updated part in a clear way.

@itskyf
Copy link
Author

itskyf commented Jan 21, 2025

Sure, I've updated the notebook to reduce the modifications.

@alvarofsan
Copy link

alvarofsan commented Jan 21, 2025

Thank you so much, @itskyf, for your contribution! Have you had a chance to test whether the execution works with ONNX Runtime?

@ZhengPeng7
Copy link
Owner

截屏2025-01-21 19 47 42 Hi, @itskyf. Did you successfully export the ONNX model? I tried it but met this problem. I tried both `PyTorch==2.0.1+onnxruntime-gpu==1.18.1` and `PyTorch==2.5.1+onnxruntime-gpu==1.20.1`).

@alvarofsan
Copy link

Hi @ZhengPeng7,

I believe the issue arises because ONNX has implemented the DeformConv operator, but unfortunately, ONNX Runtime does not currently support it. As a result, any code that includes this operator cannot be executed within a Runtime Session. :/

@itskyf
Copy link
Author

itskyf commented Jan 21, 2025

@ZhengPeng7 ah, I forgot to mention that we need to also update the onnx package for opset 19.
@alfausa1 I faced the same problem. The dynamic batched model can only be used after TensorRT conversion. But since DefirmConv is an ONNX operator, I hope it will be supported in ONNXRuntime soon.

@alvarofsan
Copy link

@itskyf Could you please provide the code in which you have converted the dynamic batched model to TensorRT? Thanks in advance!

@jhwei
Copy link

jhwei commented Jan 23, 2025

Thanks for @itskyf 's PR. This is exactly what I tested, and it worked.

I have a question about this PR for @itskyf

When I tested in this way, I found the result trt engine will work as expected when the batch size used when generating is different from batch size used for inferencing. And I figured out #166 this change should be made. Do you find the same issue ?

@ZhengPeng7
Copy link
Owner

@ZhengPeng7 ah, I forgot to mention that we need to also update the onnx package for opset 19. @alfausa1 I faced the same problem. The dynamic batched model can only be used after TensorRT conversion. But since DefirmConv is an ONNX operator, I hope it will be supported in ONNXRuntime soon.

Hi, @itskyf , sorry for the late reply, just came back from the Lunar New Year holiday :)
I've upgraded the related packages as onnx==1.17.0, onnxruntime-gpu==1.20.1, onnxscript==0.1.0, which should be all the latest versions? But I still got this error (NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for DeformConv(19) node with name '/squeeze_module/squeeze_module.0/dec_att/aspp1/atrous_conv/DeformConv'):
截屏2025-02-05 22 59 34

As you said above, the DCN is still not supported in ONNXRuntime. If so, how to use the exported birefnet.onnx file?
Thanks for your kind explanation in advance!

@alvarofsan
Copy link

alvarofsan commented Feb 7, 2025

Hi @ZhengPeng7, I might be able to help.

To export with opset >19, you’ll need to update your PyTorch version to >2.4. In the provided example, it uses opset 19 for the converter and opset 20 for the entire model.

Regarding execution, you can’t run an .onnx file directly with onnxruntime by default, because the operator is not implemented yet. I believe @jhwei is referring to converting the .onnx model to a TensorRT engine for execution.

That said, I’m not sure, but maybe you can run it natively with onnxruntime if you specify the TensorRT execution provider.

@ZhengPeng7
Copy link
Owner

Hi, @alfausa1. Thanks a lot for the details :)
Yeah, currently, the suggested and default PyTorch version used in BiRefNet is 2.5.1, which should be good here.
So, it seems that if we want to use onnxruntime's session to run it, we can only use the previously employed 3rd deformConv implementation. If we only want to run the model in TensorRT, we can use the native implementation in the latest ONNX to export .onnx files.
Is my understanding right?

@alvarofsan
Copy link

alvarofsan commented Feb 7, 2025

Hi @ZhengPeng7, you’re correct.

It might be worth testing if an onnxruntime session works by specifying the TensorRT execution provider like this:
sess = ort.InferenceSession('model.onnx', providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider'])

If that doesn’t work, maybe @jhwei can guide us on how to export and use a TensorRT engine, as there are different approaches that involve using CUDA libraries and low-level configurations.

I also found this new repo: onnx/onnx-tensorrt, which could be useful to test.

Sorry for all the information without testing it myself—my GPU resources are currently limited :((

@ZhengPeng7
Copy link
Owner

Thank you, alfausa1, I've tested it but more errors need to be fixed there and more libs needs to be installed. I'll take a deeper look into it when I have spare time.

@ShirasawaSama
Copy link

Haha, I'm at my wit's end. 😭😭😭 I tried adding dynamic image size input to BirefNet, but it's still stuck on ONNXRuntime not supporting DeformConv2D. Looks like I'll have to hand-code a TensorRT plugin after all. This is so frustrating.

This is my current work. I hope future developers can build upon it. “相信后人的智慧”

https://gist.github.com/ShirasawaSama/c231d83e3c24d10d4b706051c0c2c6f1

On the other hand, I didn't expect that the MPS of MacOS natively supports the DeformConv2d operator.

Meanwhile, I tried to export an ONNX model in fp16 format and found that the discrepancy was extremely large, which was quite strange:

Torch vs ONNX FP32: Max Diff=0.008335, Mean Diff=0.000908
ONNX FP32 vs FP16 : Max Diff=0.436847, Mean Diff=0.010630
[WARNING] FP32 precision loss is higher than expected!
[WARNING] FP16 precision loss is significant. Check output visually.

I can provide my code. Would anyone be willing to help me?

@ShirasawaSama
Copy link

I submitted a Pull Request for ONNXRuntime to enable support for the DeformConv2D operator on CPU, CUDA, and TensorRT.

microsoft/onnxruntime#27393

Next, I will conduct comparative tests using my build and birefnet to check for any accuracy degradation. I will post updates here as progress is made.

@ZhengPeng7
Copy link
Owner

Thanks for the PR. I'm not an expert on this, but still, many thanks and good luck to you!

@ShirasawaSama
Copy link

QQ_1771702163210

Latest progress: Dynamic batch and dynamic input size onnx conversion have been successfully implemented. Performance under CUDA (non-TensorRT) is comparable to Torch, with a difference of less than 20ms.

However, a rather awkward issue has arisen: although I submitted the deformconv implementation to the official repository, this operator is only available in opset19.

Specifying ONNX operator version 19 for export causes the nn.interpolate function to be replaced with opset19.resize. However, ONNXRuntime's CUDA implementation currently lacks this version of the resize operator, forcing a fallback to CPU execution. This results in performance degradation (roughly 1.5x slower, though it was about 3x slower before the native deformconv operator was implemented).

Next I will submit the implementation code for the CUDA resize operator version 19 to the ONNX Runtime official repository.

If anyone urgently needs to deploy this, I can provide the relevant code.

@alvarofsan
Copy link

alvarofsan commented Feb 26, 2026

Thanks for the effort @ShirasawaSama!!

In my opinion, it’s strange that the ONNX session performs similarly, or even worse than PyTorch. I would honestly expect an improvement in inference times. If I understood correctly, this could be due to the resize operator, right??

@ShirasawaSama
Copy link

Thanks for the effort @ShirasawaSama!!

In my opinion, it’s strange that the ONNX session performs similarly, or even worse than PyTorch. I would honestly expect an improvement in inference times. If I understood correctly, this could be due to the resize operator, right??

In fact, it might simply be because I haven't fully optimized performance yet. I've only ensured it runs on the GPU as much as possible. The model can still be fine-tuned to improve performance, and if I have more time, I'll continue trying to optimize it.

@ShirasawaSama
Copy link

RTX5080 1024*1024

Float32

image

FP16

image

@ShirasawaSama
Copy link

TensorRT:
image

Summary:

  • CUDA (Torch): 86.5ms
  • CUDA (ONNX): 128ms
  • TensorRT (ONNX): 26ms (3.31x faster)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Request: Support for deform_conv2d from an alternate repository to enable better ONNX export with dynamic_shape

5 participants