[Bug] Outer reduction axis having length 1 in TE causes strides assertion failure #13010

@guberti

Description

It is often useful to have reduction axes of length 1, e.g. to handle a 1x1 kernel in a conv2d. However, when tensorize is used and the outer reduction axis has length 1, lowering fails with an error like the following:

TVMError: Traceback (most recent call last):
  36: TVMFuncCall
  35: _ZN3tvm7runtime13PackedFun
  34: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)>::AssignTypedLambda<tvm::__mk_TVM16::{lambda(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)#1}>(tvm::__mk_TVM16::{lambda(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)#1}, std::string)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMArgs const&) const
  33: tvm::LowerSchedule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::string const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, bool)
  32: tvm::LowerWithPassList(tvm::IRModule, tvm::runtime::Array<tvm::transform::Pass, void>)
  31: tvm::transform::Pass::operator()(tvm::IRModule) const
  30: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  29: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  28: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  27: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  26: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_9transform14StorageFlattenEibEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
  25: tvm::tir::StorageFlatten(tvm::tir::PrimFunc, int, bool)
  24: tvm::transform::Pass::operator()(tvm::IRModule) const
  23: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  22: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  21: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  20: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  19: _ZN3tvm7runtime13PackedFun
  18: tvm::runtime::TypedPackedFunc<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tir::BufferBindUnwrapper::Pass()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::tir::BufferBindUnwrapper::Pass()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  17: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  16: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  15: _ZZN3tvm3tir11StmtFunctorI
  14: tvm::tir::StmtMutator::VisitStmt_(tvm::tir::ForNode const*)
  13: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  12: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  11: _ZZN3tvm3tir11StmtFunctorI
  10: tvm::tir::StmtMutator::VisitStmt_(tvm::tir::ForNode const*)
  9: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  8: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  7: _ZZN3tvm3tir11StmtFunctorI
  6: tvm::tir::BufferBindUnwrapper::VisitStmt_(tvm::tir::AttrStmtNode const*)
  5: tvm::tir::BufferBindUnwrapper::HandleBufferBindScope(tvm::tir::AttrStmtNode const*)
  4: tvm::tir::ArgBinder::BindBuffer(tvm::tir::Buffer const&, tvm::tir::Buffer const&, std::string const&, bool)
  3: tvm::tir::ArgBinder::BindArray(tvm::runtime::Array<tvm::PrimExpr, void> const&, tvm::runtime::Array<tvm::PrimExpr, void> const&, std::string const&)
  2: tvm::tir::ArgBinder::Bind_(tvm::PrimExpr const&, tvm::PrimExpr const&, std::string const&, bool)
  1: tvm::tir::BinderAddAssert(tvm::arith::Analyzer*, tvm::PrimExpr, std::string const&, std::vector<tvm::tir::Stmt, std::allocator<tvm::tir::Stmt> >*)
  0: _ZN3tvm7runtime6deta
  File "/workspace/tvm/src/tir/transforms/arg_binder.cc", line 40
TVMError: Bind have an unmet assertion: (bool)0,  on argument foobar.strides[1]
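Some context on the failing check. The traceback shows `BufferBindUnwrapper` calling `ArgBinder::BindBuffer`, which binds the strides declared for the tensorized buffer (`foobar`) against the strides of the buffer region it is matched with. The `(bool)0` suggests that the stride equality was simplified to false at compile time. The plain-Python sketch below is not TVM code, and its sizes are hypothetical. It shows how compact row-major strides follow from a shape. It also shows that an axis of extent 1 only ever takes index 0, so its stride never changes which elements are addressed. That is one plausible way a length-1 outer axis could end up with a stride that disagrees with the declared one, but it is an assumption, not a diagnosis of TVM's internals.

```python
from math import prod


def row_major_strides(shape):
    """Element strides of a compact row-major (C-order) buffer."""
    # The stride of axis i is the product of the extents of all axes inside it.
    return tuple(prod(shape[i + 1:]) for i in range(len(shape)))


def all_offsets(shape, strides):
    """Every element offset reachable by indexing a buffer with these strides."""
    offsets = [0]
    for extent, stride in zip(shape, strides):
        offsets = [base + k * stride for base in offsets for k in range(extent)]
    return sorted(offsets)


# Hypothetical sizes: a padded (H, W, C) input with W = 10 and C = 4.
padded_w, in_channels = 10, 4
print(row_major_strides((3, padded_w, in_channels)))  # (40, 4, 1)

# A window whose outer axis has extent 1 only ever uses index 0 on that axis,
# so the row stride (40) and the channel stride (4) address the same elements.
window = (1, 3 * in_channels)
print(all_offsets(window, (padded_w * in_channels, 1)) ==
      all_offsets(window, (in_channels, 1)))  # True
```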

Note that this bug only occurs when the outer reduction axis has length 1; the other reduction axes can have length 1 without problems. The issue occurs on the latest version of TVM as well as on older versions such as 0.9.0.
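To make "outer reduction axis" concrete, here is a hypothetical plain-Python loop nest for a single conv2d output value in an HWC layout. It is not the TE compute from the reproduction script. It assumes, as the workaround's comment suggests, that the outer reduction axis walks the kernel rows. With a 1x1 kernel, that loop has extent 1 and the whole reduction collapses to a dot product over channels.

```python
def conv2d_hwc_at(data, kernel, y, x):
    """Direct conv2d output at (y, x) for one output channel.

    data: nested lists indexed [H][W][C]; kernel: nested lists indexed [KH][KW][C].
    """
    kernel_h, kernel_w, channels = len(kernel), len(kernel[0]), len(kernel[0][0])
    acc = 0
    for kh in range(kernel_h):          # outer reduction axis (extent 1 for a 1x1 kernel)
        for kw in range(kernel_w):      # inner reduction axes
            for c in range(channels):
                acc += data[y + kh][x + kw][c] * kernel[kh][kw][c]
    return acc


data = [[[1, 2, 3], [4, 5, 6]],
        [[7, 8, 9], [10, 11, 12]]]
one_by_one = [[[1, 0, 2]]]  # kernel_h = kernel_w = 1, three channels
print(conv2d_hwc_at(data, one_by_one, 1, 0))  # 7*1 + 8*0 + 9*2 = 25
```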

I've found a hack to work around this bug, and have used it in some of my PRs:

# TVM has a really strange bug where the outer reduction axis (kh_i) having length 1 causes the
# decl_buffer strides check to fail. height_stride is a dark magic workaround for this.
height_stride = in_channels * padded_w if kernel_h > 1 else in_channels
jump = (padded_w - kernel_w) * in_channels
tensordot_params = (in_dtype, kernel_h, jump, kernel_w * in_channels, suffix)

The hack is pretty gross, though, so it would be nice to get a proper fix.
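The workaround's arithmetic can be isolated in a small function to make the special case explicit. `tensordot_strides` and `row_len` are hypothetical names. The sketch assumes that width and channels are the innermost input dimensions, which is inferred from `height_stride = in_channels * padded_w`.

```python
def tensordot_strides(kernel_h, kernel_w, in_channels, padded_w):
    """Hypothetical helper reproducing the workaround's stride arithmetic.

    Assumes width and channels are the innermost input dimensions, so moving
    down one input row advances padded_w * in_channels elements.
    """
    row_len = kernel_w * in_channels            # elements read per kernel row
    jump = (padded_w - kernel_w) * in_channels  # gap from one kernel row to the next
    # Workaround: with a single kernel row, use the channel count instead of the
    # true row stride so that the decl_buffer strides check passes.
    height_stride = in_channels * padded_w if kernel_h > 1 else in_channels
    return height_stride, jump, row_len


print(tensordot_strides(3, 3, 4, 10))  # (40, 28, 12)
print(tensordot_strides(1, 1, 4, 10))  # (4, 36, 4)
```

For `kernel_h > 1`, `jump + row_len` equals `height_stride`, the true distance between consecutive input rows. Only the `kernel_h == 1` branch departs from that layout.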

Steps to reproduce

The easiest way to reproduce this is to use a Colab notebook:

https://colab.research.google.com/drive/1Y9LXBdQxQD-FjNbW6cChuExse6IvMdGZ

You can also reproduce it using this script: bug_reproduction.py.txt

Metadata

Assignees

No one assigned

    Labels

    needs-triage, type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests