Description
Oftentimes it is useful to have reduction axes with length 1 - e.g. to deal with a 1x1 kernel for a conv2d. However, when using tensorize in such a case, where the outer reduction axis has length one, you get an error like the following:
TVMError: Traceback (most recent call last):
36: TVMFuncCall
35: _ZN3tvm7runtime13PackedFun
34: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)>::AssignTypedLambda<tvm::__mk_TVM16::{lambda(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)#1}>(tvm::__mk_TVM16::{lambda(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)#1}, std::string)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMArgs const&) const
33: tvm::LowerSchedule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::string const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, bool)
32: tvm::LowerWithPassList(tvm::IRModule, tvm::runtime::Array<tvm::transform::Pass, void>)
31: tvm::transform::Pass::operator()(tvm::IRModule) const
30: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
29: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
28: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
27: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
26: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_9transform14StorageFlattenEibEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
25: tvm::tir::StorageFlatten(tvm::tir::PrimFunc, int, bool)
24: tvm::transform::Pass::operator()(tvm::IRModule) const
23: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
22: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
21: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
20: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
19: _ZN3tvm7runtime13PackedFun
18: tvm::runtime::TypedPackedFunc<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tir::BufferBindUnwrapper::Pass()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::tir::BufferBindUnwrapper::Pass()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
17: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
16: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
15: _ZZN3tvm3tir11StmtFunctorI
14: tvm::tir::StmtMutator::VisitStmt_(tvm::tir::ForNode const*)
13: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
12: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
11: _ZZN3tvm3tir11StmtFunctorI
10: tvm::tir::StmtMutator::VisitStmt_(tvm::tir::ForNode const*)
9: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
8: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
7: _ZZN3tvm3tir11StmtFunctorI
6: tvm::tir::BufferBindUnwrapper::VisitStmt_(tvm::tir::AttrStmtNode const*)
5: tvm::tir::BufferBindUnwrapper::HandleBufferBindScope(tvm::tir::AttrStmtNode const*)
4: tvm::tir::ArgBinder::BindBuffer(tvm::tir::Buffer const&, tvm::tir::Buffer const&, std::string const&, bool)
3: tvm::tir::ArgBinder::BindArray(tvm::runtime::Array<tvm::PrimExpr, void> const&, tvm::runtime::Array<tvm::PrimExpr, void> const&, std::string const&)
2: tvm::tir::ArgBinder::Bind_(tvm::PrimExpr const&, tvm::PrimExpr const&, std::string const&, bool)
1: tvm::tir::BinderAddAssert(tvm::arith::Analyzer*, tvm::PrimExpr, std::string const&, std::vector<tvm::tir::Stmt, std::allocator<tvm::tir::Stmt> >*)
0: _ZN3tvm7runtime6deta
File "/workspace/tvm/src/tir/transforms/arg_binder.cc", line 40
TVMError: Bind have an unmet assertion: (bool)0, on argument foobar.strides[1]
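(For context on where such length-1 reduction axes come from: a conv2d with a 1x1 kernel still reduces over the kernel-height and kernel-width axes, each with extent 1. Below is a plain-Python sketch of that loop nest - no TVM involved, and all names and shapes are illustrative only.)

```python
def conv2d_nhwc(data, kernel):
    """Naive conv2d over nested lists: data[h][w][c_in], kernel[kh][kw][c_in][c_out]."""
    h, w, c_in = len(data), len(data[0]), len(data[0][0])
    kh, kw = len(kernel), len(kernel[0])
    c_out = len(kernel[0][0][0])
    out_h, out_w = h - kh + 1, w - kw + 1
    out = [[[0.0] * c_out for _ in range(out_w)] for _ in range(out_h)]
    for y in range(out_h):
        for x in range(out_w):
            for co in range(c_out):
                acc = 0.0
                # Reduction axes: for a 1x1 kernel, rh and rw each have extent 1,
                # so the "outer" reduction loop runs exactly once.
                for rh in range(kh):
                    for rw in range(kw):
                        for ci in range(c_in):
                            acc += data[y + rh][x + rw][ci] * kernel[rh][rw][ci][co]
                out[y][x][co] = acc
    return out

# A 1x1 kernel that just doubles every value:
print(conv2d_nhwc([[[1.0], [2.0]], [[3.0], [4.0]]], [[[[2.0]]]]))
# -> [[[2.0], [4.0]], [[6.0], [8.0]]]
```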
Note that this bug occurs only when the outer reduction axis has length one - all other lengths work fine. The issue occurs on the latest version of TVM, as well as on older versions such as 0.9.0.
I've found a hack to work around this bug, and have used it in some of my PRs:
tvm/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py
Lines 158 to 162 in 981b1bd
```python
# TVM has a really strange bug where the outer reduction axis (kh_i) having length 1 causes the
# decl_buffer strides check to fail. height_stride is a dark magic workaround for this.
height_stride = in_channels * padded_w if kernel_h > 1 else in_channels
jump = (padded_w - kernel_w) * in_channels
tensordot_params = (in_dtype, kernel_h, jump, kernel_w * in_channels, suffix)
```
The hack is pretty gross, though, so it would be nice to get a proper fix.
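To make the workaround's arithmetic concrete, here is a standalone restatement of the snippet above (the function name and the parameter values are hypothetical; only the stride expressions come from the snippet):

```python
def workaround_strides(in_channels, padded_w, kernel_w, kernel_h):
    # Normally the row stride would always be in_channels * padded_w, but when
    # kernel_h == 1 that value trips the decl_buffer strides assertion, so the
    # workaround substitutes in_channels. This is presumably harmless because a
    # length-1 reduction axis never actually advances by its stride.
    height_stride = in_channels * padded_w if kernel_h > 1 else in_channels
    jump = (padded_w - kernel_w) * in_channels
    return height_stride, jump

print(workaround_strides(4, 10, 3, 2))  # -> (40, 28)
print(workaround_strides(4, 10, 3, 1))  # -> (4, 28)
```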
Steps to reproduce
The easiest way to reproduce this is to use a Colab notebook:
https://colab.research.google.com/drive/1Y9LXBdQxQD-FjNbW6cChuExse6IvMdGZ
You can also reproduce it using this script: bug_reproduction.py.txt