From f67ad2fd4f5cc51f6ae2f102d7e47f6cce478766 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 8 May 2025 22:59:38 +0200 Subject: [PATCH 1/4] Converting shapeinference to return errors. --- backends/shapeinference/shapeinference.go | 185 +++++++++++------- .../shapeinference/shapeinference_test.go | 64 ++++-- backends/simplego/ops.go | 5 +- docs/CHANGELOG.md | 3 +- 4 files changed, 161 insertions(+), 96 deletions(-) diff --git a/backends/shapeinference/shapeinference.go b/backends/shapeinference/shapeinference.go index 7b85a236..8255706c 100644 --- a/backends/shapeinference/shapeinference.go +++ b/backends/shapeinference/shapeinference.go @@ -1,16 +1,14 @@ -// Package shapeinference calculates the shape resulting from operations. +// Package shapeinference calculates the shape resulting from operations, and validates its inputs. // // This can be useful for new backends to test and help plan for buffer space for temporary or output buffers. // -// It defines BinaryOp function for shape inference for the majority of binary functions, using the standard +// It defines a BinaryOp function for shape inference for the majority of binary functions, using the standard // broadcasting rules. // // The majority of the unary functions don't change the shape, except those that explicitly say that in their name, // like Reshape, etc. // // For the remainder ops, it defines one function per OpType. -// -// It also defines some classes of operations that can be used. package shapeinference import ( @@ -19,6 +17,7 @@ import ( "github.com/gomlx/gomlx/types" "github.com/gomlx/gomlx/types/shapes" "github.com/gomlx/gopjrt/dtypes" + "github.com/pkg/errors" "slices" ) @@ -175,119 +174,141 @@ var ( // // It may throw (panic) an exception if the data type (shape.DType) is invalid for the operation -- e.g.: non-matching // dtypes, or LogicalAnd not having booleans (dtype.Bool) as input. 
-func BinaryOp(opType backends.OpType, lhsShape, rhsShape shapes.Shape) shapes.Shape { +func BinaryOp(opType backends.OpType, lhsShape, rhsShape shapes.Shape) (output shapes.Shape, err error) { if !StandardBinaryOperations.Has(opType) { - exceptions.Panicf("operations %s is not in the StandardBinaryOperations set, cannot process it with BinaryOp", opType) + err = errors.Errorf("operations %s is not in the StandardBinaryOperations set, cannot process it with BinaryOp", opType) + return } if lhsShape.DType == dtypes.InvalidDType || rhsShape.DType == dtypes.InvalidDType { - exceptions.Panicf("invalid shape for %s or %s for BinaryOp %s", lhsShape, rhsShape, opType) + err = errors.Errorf("invalid shape for %s or %s for BinaryOp %s", lhsShape, rhsShape, opType) + return } if lhsShape.DType != rhsShape.DType { - exceptions.Panicf("data types (DType) for BinaryOp %s must match, got %s and %s", opType, lhsShape, rhsShape) + err = errors.Errorf("data types (DType) for BinaryOp %s must match, got %s and %s", opType, lhsShape, rhsShape) + return } if BooleanOperations.Has(opType) && lhsShape.DType != dtypes.Bool { - exceptions.Panicf("logical BinaryOp %s must have boolean (dtype.Bool) data types as input, got %s", opType, lhsShape) + err = errors.Errorf("logical BinaryOp %s must have boolean (dtype.Bool) data types as input, got %s", opType, lhsShape) + return } if BitwiseOperations.Has(opType) && !lhsShape.DType.IsInt() { - exceptions.Panicf("bitwise BinaryOp %s must have an integer (Int8, UInt8, Int32, ...) data type as input, got %s", opType, lhsShape) + err = errors.Errorf("bitwise BinaryOp %s must have an integer (Int8, UInt8, Int32, ...) data type as input, got %s", opType, lhsShape) + return } if NumberOperations.Has(opType) && !(lhsShape.DType.IsInt() || lhsShape.DType.IsFloat() || lhsShape.DType.IsComplex()) { - exceptions.Panicf("numeric BinaryOp %s must have a number (Int32, Float32, Complex64, ...) 
data type as input, got %s", opType, lhsShape) + err = errors.Errorf("numeric BinaryOp %s must have a number (Int32, Float32, Complex64, ...) data type as input, got %s", opType, lhsShape) + return } if FloatOperations.Has(opType) && !lhsShape.DType.IsFloat() { - exceptions.Panicf("float BinaryOp %s must have a float (Float32, Float64, ...) data type as input, got %s", opType, lhsShape) + err = errors.Errorf("float BinaryOp %s must have a float (Float32, Float64, ...) data type as input, got %s", opType, lhsShape) + return } if FloatOrComplexOperations.Has(opType) && !(lhsShape.DType.IsFloat() || lhsShape.DType.IsComplex()) { - exceptions.Panicf("float/complex BinaryOp %s must have a float or complex (Float32, Complex64, ...) data type as input, got %s", opType, lhsShape) + err = errors.Errorf("float/complex BinaryOp %s must have a float or complex (Float32, Complex64, ...) data type as input, got %s", opType, lhsShape) + return } if ComplexOperations.Has(opType) && !lhsShape.DType.IsComplex() { - exceptions.Panicf("complex BinaryOp %s must have a complex (Complex64, Complex128) data type as input, got %s", opType, lhsShape) + err = errors.Errorf("complex BinaryOp %s must have a complex (Complex64, Complex128) data type as input, got %s", opType, lhsShape) + return } return binaryOpImpl(opType, lhsShape, rhsShape) } -func binaryOpImpl(opType backends.OpType, lhsShape, rhsShape shapes.Shape) shapes.Shape { +func binaryOpImpl(opType backends.OpType, lhsShape, rhsShape shapes.Shape) (output shapes.Shape, err error) { // Trivial cases: if one of the sides is a scalar, return the other side shape. if lhsShape.IsScalar() { - return rhsShape + return rhsShape, nil } if rhsShape.IsScalar() { - return lhsShape + return lhsShape, nil } // Other cases, either the dimensions match or one of them is 1. 
if lhsShape.Rank() != rhsShape.Rank() { - exceptions.Panicf("if operands are not scalars, their rank must match for BinaryOp (%s), got shapes %s and %s", + err = errors.Errorf("if operands are not scalars, their rank must match for BinaryOp (%s), got shapes %s and %s", opType, lhsShape, rhsShape) + // Stop here: the broadcast loop below assumes equal ranks (it indexes rhsShape.Dimensions by lhs axes). + return } - shape := lhsShape.Clone() - for axis := range shape.Rank() { + output = lhsShape.Clone() + for axis := range output.Rank() { lhsDim := lhsShape.Dimensions[axis] rhsDim := rhsShape.Dimensions[axis] if lhsDim != 1 && rhsDim != 1 && lhsDim != rhsDim { - exceptions.Panicf("dimension of axis #%d doesn't match and cannot be broadcast for BinaryOp (%s), got shapes %s and %s", + err = errors.Errorf("dimension of axis #%d doesn't match and cannot be broadcast for BinaryOp (%s), got shapes %s and %s", axis, opType, lhsShape, rhsShape) + return } - shape.Dimensions[axis] = max(lhsDim, rhsDim) + output.Dimensions[axis] = max(lhsDim, rhsDim) } - return shape + return } // ComparisonOp returns the broadcast shape with dtype set to Bool, for comparison operations (Equal, LessThan, GreaterOrEqual, etc.)
-func ComparisonOp(opType backends.OpType, lhsShape, rhsShape shapes.Shape) shapes.Shape { +func ComparisonOp(opType backends.OpType, lhsShape, rhsShape shapes.Shape) (output shapes.Shape, err error) { if !ComparisonOperations.Has(opType) { - exceptions.Panicf("operation %s is not in the ComparisonOperations set, cannot process it with ComparisonOp", opType) + err = errors.Errorf("operation %s is not in the ComparisonOperations set, cannot process it with ComparisonOp", opType) + return } if lhsShape.DType == dtypes.InvalidDType || rhsShape.DType == dtypes.InvalidDType { - exceptions.Panicf("invalid shape for %s or %s for ComparisonOp %s", lhsShape, rhsShape, opType) + err = errors.Errorf("invalid shape for %s or %s for ComparisonOp %s", lhsShape, rhsShape, opType) + return } if lhsShape.DType != rhsShape.DType { - exceptions.Panicf("data types (DType) for ComparisonOp %s must match, got %s and %s", opType, lhsShape, rhsShape) + err = errors.Errorf("data types (DType) for ComparisonOp %s must match, got %s and %s", opType, lhsShape, rhsShape) + return } if !NumberOperations.Has(opType) { - exceptions.Panicf("operation %s is not in the NumberOperations set, cannot process it with ComparisonOp", opType) + err = errors.Errorf("operation %s is not in the NumberOperations set, cannot process it with ComparisonOp", opType) + return } - shape := binaryOpImpl(opType, lhsShape, rhsShape) - shape.DType = dtypes.Bool - return shape + output, err = binaryOpImpl(opType, lhsShape, rhsShape) + if err != nil { + // Don't decorate a meaningless output with the Bool dtype; just propagate the error. + return + } + output.DType = dtypes.Bool + return } -// UnaryOp checks the validity of the data type for StandardUnaryOperations set and throws an exception -// (panic) in case of mismatch. -// -// It returns the same shape as the operand -- there is no broadcast on standard unary operations.
-func UnaryOp(opType backends.OpType, operand shapes.Shape) shapes.Shape { +// UnaryOp checks the validity of the data type for StandardUnaryOperations and returns either an error or +// the output shape, which it the same as the operand. +func UnaryOp(opType backends.OpType, operand shapes.Shape) (output shapes.Shape, err error) { if !StandardUnaryOperations.Has(opType) { - exceptions.Panicf("operation %s is not in the StandardUnaryOperations set, cannot process it with UnaryOp", opType) + err = errors.Errorf("operation %s is not in the StandardUnaryOperations set, cannot process it with UnaryOp", opType) + return } if operand.DType == dtypes.InvalidDType { - exceptions.Panicf("invalid shape %s for UnaryOp %s", operand, opType) + err = errors.Errorf("invalid shape %s for UnaryOp %s", operand, opType) + return } if BooleanOperations.Has(opType) && operand.DType != dtypes.Bool { - exceptions.Panicf("logical UnaryOp %s must have boolean (dtype.Bool) data types as input, got %s", opType, operand) + err = errors.Errorf("logical UnaryOp %s must have boolean (dtype.Bool) data types as input, got %s", opType, operand) + return } if BitwiseOperations.Has(opType) && !operand.DType.IsInt() { - exceptions.Panicf("bitwise UnaryOp %s must have an integer (Int8, UInt8, Int32, ...) data type as input, got %s", opType, operand) + err = errors.Errorf("bitwise UnaryOp %s must have an integer (Int8, UInt8, Int32, ...) 
data type as input, got %s", opType, operand) + return } if SignedNumberOperations.Has(opType) && (operand.DType.IsUnsigned() || !(operand.DType.IsInt() || operand.DType.IsFloat() || operand.DType.IsComplex())) { - exceptions.Panicf("signed UnaryOp %s must have a signed data type as input, got %s", opType, operand) + err = errors.Errorf("signed UnaryOp %s must have a signed data type as input, got %s", opType, operand) + return } if NumberOperations.Has(opType) && !(operand.DType.IsInt() || operand.DType.IsFloat() || operand.DType.IsComplex()) { - exceptions.Panicf("numeric UnaryOp %s must have a number (Int32, Float32, Complex64, ...) data type as input, got %s", opType, operand) + err = errors.Errorf("numeric UnaryOp %s must have a number (Int32, Float32, Complex64, ...) data type as input, got %s", opType, operand) + return } if FloatOperations.Has(opType) && !operand.DType.IsFloat() { - exceptions.Panicf("float UnaryOp %s must have a float (Float32, Float64, ...) data type as input, got %s", opType, operand) + err = errors.Errorf("float UnaryOp %s must have a float (Float32, Float64, ...) data type as input, got %s", opType, operand) + return } if FloatOrComplexOperations.Has(opType) && !(operand.DType.IsFloat() || operand.DType.IsComplex()) { - exceptions.Panicf("float/complex UnaryOp %s must have a float or complex (Float32, Complex64, ...) data type as input, got %s", opType, operand) + err = errors.Errorf("float/complex UnaryOp %s must have a float or complex (Float32, Complex64, ...) 
data type as input, got %s", opType, operand) + return } if ComplexOperations.Has(opType) && !operand.DType.IsComplex() { - exceptions.Panicf("complex UnaryOp %s must have a complex (Complex64, Complex128) data type as input, got %s", opType, operand) + err = errors.Errorf("complex UnaryOp %s must have a complex (Complex64, Complex128) data type as input, got %s", opType, operand) + return } - return operand + output = operand + return } // WhereOp returns the shape resulting from the Where operation. @@ -297,16 +318,18 @@ func UnaryOp(opType backends.OpType, operand shapes.Shape) shapes.Shape { // 1. onTrue and onFalse must have the exact same shape. // 2. condition must either be a scalar or match the shape of onTrue and onFalse, except for the DType that // must be Bool. -func WhereOp(condition, onTrue, onFalse shapes.Shape) shapes.Shape { +func WhereOp(condition, onTrue, onFalse shapes.Shape) (output shapes.Shape, err error) { if condition.DType != dtypes.Bool { - exceptions.Panicf("condition for Where() must be a boolean, got %s instead", condition) + err = errors.Errorf("condition for Where() must be a boolean, got %s instead", condition) + return } if !onTrue.IsScalar() && !onFalse.IsScalar() && !onTrue.Equal(onFalse) { - exceptions.Panicf("onTrue (%s) and onFalse (%s) values for Where() must either be scalar or match each other's shape", + err = errors.Errorf("onTrue (%s) and onFalse (%s) values for Where() must either be scalar or match each other's shape", onTrue, onFalse) + return } - output := onTrue + output = onTrue if output.IsScalar() { output = onFalse if output.IsScalar() && !condition.IsScalar() { @@ -316,37 +339,40 @@ func WhereOp(condition, onTrue, onFalse shapes.Shape) shapes.Shape { } if !condition.IsScalar() && slices.Compare(condition.Dimensions, output.Dimensions) != 0 { - exceptions.Panicf("condition for Where() must either be a scalar or match the output shape (not the DType), instead got shapes condition=%s, onTrue=%s and onFalse=%s", + 
err = errors.Errorf("condition for Where() must either be a scalar or match the output shape (not the DType), instead got shapes condition=%s, onTrue=%s and onFalse=%s", condition, onTrue, onFalse) + return } - return output + return } // ReshapeOp to the given dimensions: trivial output shape, but this function also checks // that the sizes are the same. // // Notice the backends.Reshape doesn't support auto-scaling dimensions (set to -1), as graph.Reshape does. -func ReshapeOp(operand shapes.Shape, dims []int) shapes.Shape { - output := shapes.Make(operand.DType, dims...) +func ReshapeOp(operand shapes.Shape, dims []int) (output shapes.Shape, err error) { + output = shapes.Make(operand.DType, dims...) if operand.Size() != output.Size() { - exceptions.Panicf("Reshape() cannot reshape %s to dimensions %v, their size don't match", + err = errors.Errorf("Reshape() cannot reshape %s to dimensions %v, their size don't match", operand, dims) + return shapes.Invalid(), err } - return output + return } // TransposeOp all axes of the operand. // There must be one value in permutations for each axis in the operand. // The output will have: output.Shape.Dimension[ii] = operand.Shape.Dimension[permutations[i]]. -func TransposeOp(operand shapes.Shape, permutations []int) shapes.Shape { +func TransposeOp(operand shapes.Shape, permutations []int) (output shapes.Shape, err error) { rank := operand.Rank() if len(permutations) != rank { - exceptions.Panicf("Transpose() requires all axes permutations to be defined, operand has shape %s, but %d permutations were given", + err = errors.Errorf("Transpose() requires all axes permutations to be defined, operand has shape %s, but %d permutations were given", operand, len(permutations)) + return } if rank == 0 { - return operand + return operand, nil } // Check permutation axes are within range and unique. 
@@ -354,39 +380,51 @@ func TransposeOp(operand shapes.Shape, permutations []int) shapes.Shape { slices.Sort(axesSet) for ii, srcAxis := range axesSet { if srcAxis < 0 || srcAxis >= rank { - exceptions.Panicf("invalid permutation axis %d given to Transpose(%s), it must be within the range of its rank", + err = errors.Errorf("invalid permutation axis %d given to Transpose(%s), it must be within the range of its rank", srcAxis, operand) + return } if ii > 0 && srcAxis == axesSet[ii-1] { - exceptions.Panicf("invalid permutations given to Transpose(%s, %v), there cannot be any repeated axis, each must appear exactly once", + err = errors.Errorf("invalid permutations given to Transpose(%s, %v), there cannot be any repeated axis, each must appear exactly once", operand, permutations) + return } } - output := operand.Clone() + output = operand.Clone() for axis := range output.Dimensions { srcAxis := permutations[axis] output.Dimensions[axis] = operand.Dimensions[srcAxis] } - return output + return } // BroadcastOp adds the prefixDims to the start of the shape. -func BroadcastOp(operand shapes.Shape, prefixDims []int) shapes.Shape { +func BroadcastOp(operand shapes.Shape, prefixDims []int) (output shapes.Shape, err error) { + if operand.DType == dtypes.InvalidDType { + err = errors.Errorf("invalid shape %s for BroadcastOp", operand) + return + } if len(prefixDims) == 0 { - return operand + return operand, nil } - output := shapes.Make(operand.DType) + for _, dim := range prefixDims { + if dim <= 0 { + err = errors.Errorf("Invalid prefix dimensions %v for BroadcastOp, they must be positive", prefixDims) + return + } + } + output = shapes.Make(operand.DType) output.Dimensions = make([]int, len(prefixDims)+operand.Rank()) copy(output.Dimensions, prefixDims) copy(output.Dimensions[len(prefixDims):], operand.Dimensions) - return output + return } -// BroadcastInDimOp verifies the arguments are valid. The output shape is already known, so nothing is returned. 
-func BroadcastInDimOp(operand, outputShape shapes.Shape, broadcastAxes []int) { +// BroadcastInDimOp verifies that the arguments are valid. The output shape is already known, so nothing is returned. +func BroadcastInDimOp(operand, outputShape shapes.Shape, broadcastAxes []int) error { if len(broadcastAxes) != operand.Rank() { - exceptions.Panicf("there must be exactly one broadcastAxes (%v) per axis in the operand (%s)", + return errors.Errorf("there must be exactly one broadcastAxes (%v) per axis in the operand (%s)", broadcastAxes, operand) } @@ -394,21 +432,22 @@ func BroadcastInDimOp(operand, outputShape shapes.Shape, broadcastAxes []int) { preservedSet := types.MakeSet[int](len(broadcastAxes)) for axisInOperand, axisInOutput := range broadcastAxes { if axisInOutput < 0 || axisInOutput >= outputShape.Rank() { - exceptions.Panicf("broadcastAxes (%v) defines a value out-of-range (%d-th value -> %d), they must be between 0 and outputShape.Rank()-1=%d", + return errors.Errorf("broadcastAxes (%v) defines a value out-of-range (%d-th value -> %d), they must be between 0 and outputShape.Rank()-1=%d", broadcastAxes, axisInOperand, axisInOutput, outputShape.Rank()-1) } if preservedSet.Has(axisInOutput) { - exceptions.Panicf("broadcastAxes (%v) repeats axis %d (broadcastAxes[%d]), they must be all unique and between 0 and outputShape.Rank()-1=%d", + return errors.Errorf("broadcastAxes (%v) repeats axis %d (broadcastAxes[%d]), they must be all unique and between 0 and outputShape.Rank()-1=%d", broadcastAxes, axisInOutput, axisInOperand, outputShape.Rank()-1) } preservedSet.Insert(axisInOutput) if operand.Dimensions[axisInOperand] != 1 && operand.Dimensions[axisInOperand] != outputShape.Dimensions[axisInOutput] { - exceptions.Panicf("the values of outputShape (%v) that are being broadcast (listed in broadcastAxes) "+ + return errors.Errorf("the values of outputShape (%v) that are being broadcast (listed in broadcastAxes) "+ "must match the corresponding value in the 
operand shape (%s) or be 1 (if broadcasting), "+ "but the value of outputShape.Dimensions[%d]=%d does not match the value in operand.Shape().Dimensions[%d]=%d", outputShape, operand, axisInOutput, outputShape.Dimensions[axisInOutput], axisInOperand, operand.Dimensions[axisInOperand]) } } + return nil } // ReduceOp works for the ReduceMax, ReduceMin, ReduceSum and ReduceProduct ops. diff --git a/backends/shapeinference/shapeinference_test.go b/backends/shapeinference/shapeinference_test.go index 8b52af1d..9e594563 100644 --- a/backends/shapeinference/shapeinference_test.go +++ b/backends/shapeinference/shapeinference_test.go @@ -20,63 +20,85 @@ var ( MS = shapes.Make ) +// must1 panics if there is an error. +func must1[T any](value T, err error) T { + if err != nil { + panic(err) + } + return value +} + func TestBinaryOp(t *testing.T) { // Invalid data types check. - require.Panics(t, func() { BinaryOp(OpTypeLogicalAnd, MS(I8), MS(I8)) }) - require.Panics(t, func() { BinaryOp(OpTypeMul, MS(Bool, 1), MS(Bool, 1)) }) - require.Panics(t, func() { BinaryOp(OpTypeMul, MS(Bool, 1), MS(Bool, 1)) }) - require.Panics(t, func() { BinaryOp(OpTypeBitwiseXor, MS(F32, 1), MS(F32, 1)) }) + var err error + _, err = BinaryOp(OpTypeLogicalAnd, MS(I8), MS(I8)) + require.Error(t, err) + _, err = BinaryOp(OpTypeMul, MS(Bool, 1), MS(Bool, 1)) + require.Error(t, err) + _, err = BinaryOp(OpTypeMul, MS(Bool, 1), MS(Bool, 1)) + require.Error(t, err) + _, err = BinaryOp(OpTypeBitwiseXor, MS(F32, 1), MS(F32, 1)) + require.Error(t, err) // Invalid operation type (not binary op). - require.Panics(t, func() { BinaryOp(OpTypeExp, MS(F32), MS(F32)) }) + _, err = BinaryOp(OpTypeExp, MS(F32), MS(F32)) + require.Error(t, err) // The same shape should be ok. 
+ var output shapes.Shape intMatrixShape := MS(I8, 3, 3) - require.True(t, intMatrixShape.Equal(BinaryOp(OpTypeBitwiseOr, intMatrixShape, intMatrixShape))) + output, err = BinaryOp(OpTypeBitwiseOr, intMatrixShape, intMatrixShape) + require.NoError(t, err) + require.True(t, intMatrixShape.Equal(output)) // Scalar with matrix. scalarShape := MS(F32) matrixShape := MS(F32, 2, 3) expectedShape := MS(F32, 2, 3) - require.True(t, scalarShape.Equal(BinaryOp(OpTypeAdd, scalarShape, scalarShape))) - require.True(t, expectedShape.Equal(BinaryOp(OpTypeAdd, scalarShape, matrixShape))) + output, err = BinaryOp(OpTypeAdd, scalarShape, scalarShape) + require.NoError(t, err) + require.True(t, scalarShape.Equal(output)) + output, err = BinaryOp(OpTypeAdd, scalarShape, matrixShape) + require.NoError(t, err) + require.True(t, expectedShape.Equal(output)) // Broadcasting on both sides. shape1 := MS(F32, 2, 1, 3) shape2 := MS(F32, 1, 4, 3) expectedBroadcastShape := MS(F32, 2, 4, 3) - require.True(t, expectedBroadcastShape.Equal(BinaryOp(OpTypeMul, shape1, shape2))) + require.True(t, expectedBroadcastShape.Equal(must1(BinaryOp(OpTypeMul, shape1, shape2)))) // Matrix with scalar. - require.True(t, expectedShape.Equal(BinaryOp(OpTypeAdd, matrixShape, scalarShape))) + require.True(t, expectedShape.Equal(must1(BinaryOp(OpTypeAdd, matrixShape, scalarShape)))) // Invalid broadcasting shapes. invalidShape1 := MS(F32, 2, 3) invalidShape2 := MS(F32, 3, 2) - require.Panics(t, func() { BinaryOp(OpTypeAdd, invalidShape1, invalidShape2) }) + _, err = BinaryOp(OpTypeAdd, invalidShape1, invalidShape2) + require.Error(t, err) } func TestUnaryOp(t *testing.T) { // Invalid data types check. 
- require.Panics(t, func() { UnaryOp(OpTypeLogicalNot, MS(F32)) }) - require.Panics(t, func() { UnaryOp(OpTypeLogicalNot, MS(I8)) }) - require.Panics(t, func() { UnaryOp(OpTypeBitwiseNot, MS(F32)) }) - require.Panics(t, func() { UnaryOp(OpTypeNeg, MS(Bool)) }) + require.Panics(t, func() { must1(UnaryOp(OpTypeLogicalNot, MS(F32))) }) + require.Panics(t, func() { must1(UnaryOp(OpTypeLogicalNot, MS(I8))) }) + require.Panics(t, func() { must1(UnaryOp(OpTypeBitwiseNot, MS(F32))) }) + require.Panics(t, func() { must1(UnaryOp(OpTypeNeg, MS(Bool))) }) // Invalid operation type (not unary op). - require.Panics(t, func() { UnaryOp(OpTypeAdd, MS(F32)) }) - require.Panics(t, func() { UnaryOp(OpTypeNeg, MS(U64)) }) + require.Panics(t, func() { must1(UnaryOp(OpTypeAdd, MS(F32))) }) + require.Panics(t, func() { must1(UnaryOp(OpTypeNeg, MS(U64))) }) // Valid operations boolShape := MS(Bool, 2, 3) - require.True(t, boolShape.Equal(UnaryOp(OpTypeLogicalNot, boolShape))) + require.True(t, boolShape.Equal(must1(UnaryOp(OpTypeLogicalNot, boolShape)))) intShape := MS(I8, 3, 3) - require.True(t, intShape.Equal(UnaryOp(OpTypeBitwiseNot, intShape))) + require.True(t, intShape.Equal(must1(UnaryOp(OpTypeBitwiseNot, intShape)))) floatShape := MS(F32, 2, 3) - require.True(t, floatShape.Equal(UnaryOp(OpTypeExp, floatShape))) - require.True(t, floatShape.Equal(UnaryOp(OpTypeNeg, floatShape))) + require.True(t, floatShape.Equal(must1(UnaryOp(OpTypeExp, floatShape)))) + require.True(t, floatShape.Equal(must1(UnaryOp(OpTypeNeg, floatShape)))) } func TestGatherOp(t *testing.T) { diff --git a/backends/simplego/ops.go b/backends/simplego/ops.go index baf77fd3..a12d078e 100644 --- a/backends/simplego/ops.go +++ b/backends/simplego/ops.go @@ -89,7 +89,10 @@ func (b *Builder) Identity(operandOp backends.Op) backends.Op { func (b *Builder) Where(conditionOp, onTrueOp, onFalseOp backends.Op) backends.Op { inputs := b.checkOps("Where", conditionOp, onTrueOp, onFalseOp) condition, onTrue, onFalse := 
inputs[0], inputs[1], inputs[2] - outputShape := shapeinference.WhereOp(condition.shape, onTrue.shape, onFalse.shape) + outputShape, err := shapeinference.WhereOp(condition.shape, onTrue.shape, onFalse.shape) + if err != nil { + panic(err) + } return b.newNode(backends.OpTypeWhere, outputShape, condition, onTrue, onFalse) } diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 86cefdbd..752c20a6 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -10,6 +10,7 @@ * Only include XLA by default on linux/amd64 platforms. * Package `types/tensors`: * Removed dependency to `gopjrt/pjrt` -- otherwise we'll always need to install the C/C++ library. +* gofmt cleanups by @zjtv # v0.19.1: 2025/04/30 SimpleGo fixes and new ops; New XLA, requires Gopjrt v0.7.0 update. @@ -28,7 +29,7 @@ * Added sub-package `notimplemented`: helper to implement new backends. * Added sub-package `shapeinference`: helper to implement new backends. * Added sub-package `default` which includes the default packages. - * Added `List()` function to returned the currently registered (compiled-in) backends. + * Added `List()` function that returns the currently registered (compiled-in) backends. * Package `checkpoints` * Added `Config.FromEmbed` that allows loading a checkpoint from an embedded variable. * Package `graph`: From a15bf26f124425e142c0d5b38e2de504575406ac Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 8 May 2025 23:05:33 +0200 Subject: [PATCH 2/4] Converting shapeinference to return errors. 
--- backends/shapeinference/shapeinference.go | 18 +++++++++++------- backends/shapeinference/shapeinference_test.go | 12 ++++++------ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/backends/shapeinference/shapeinference.go b/backends/shapeinference/shapeinference.go index 8255706c..b400e1c2 100644 --- a/backends/shapeinference/shapeinference.go +++ b/backends/shapeinference/shapeinference.go @@ -780,21 +780,25 @@ func SliceOp(operand shapes.Shape, starts, limits, strides []int) shapes.Shape { // ArgMinMaxOp calculates the output shape for an ArgMinMax operation. // It will be the shape of the operand minus the reduce axis. -func ArgMinMaxOp(operand shapes.Shape, axis int, outputDType dtypes.DType) shapes.Shape { +func ArgMinMaxOp(operand shapes.Shape, axis int, outputDType dtypes.DType) (output shapes.Shape, err error) { if !outputDType.IsInt() { - exceptions.Panicf("ArgMinMax outputDType must be an integer type, got %s", outputDType) + err = errors.Errorf("ArgMinMax outputDType must be an integer type, got %s", outputDType) + return } if !operand.DType.IsFloat() && !operand.DType.IsInt() { - exceptions.Panicf("ArgMinMax operand DType must be a floating point or integer type, got %s", operand) + err = errors.Errorf("ArgMinMax operand DType must be a floating point or integer type, got %s", operand) + return } if operand.IsScalar() { - exceptions.Panicf("ArgMinMax requires a non-scalar operand, got %s", operand) + err = errors.Errorf("ArgMinMax requires a non-scalar operand, got %s", operand) + return } if axis < 0 || axis >= operand.Rank() { - exceptions.Panicf("ArgMinMax axis %d is out of range for operand %s", axis, operand) + err = errors.Errorf("ArgMinMax axis %d is out of range for operand %s", axis, operand) + return } newDims := slices.Clone(operand.Dimensions) newDims = slices.Delete(newDims, axis, axis+1) - output := shapes.Make(outputDType, newDims...) - return output + output = shapes.Make(outputDType, newDims...) 
+ return } diff --git a/backends/shapeinference/shapeinference_test.go b/backends/shapeinference/shapeinference_test.go index 9e594563..1cc7aaf7 100644 --- a/backends/shapeinference/shapeinference_test.go +++ b/backends/shapeinference/shapeinference_test.go @@ -378,35 +378,35 @@ func TestArgMinMaxOp(t *testing.T) { // Case 1: 1D tensor operand1 := MS(F32, 10) expected1 := MS(I32) - output1 := ArgMinMaxOp(operand1, 0, I32) + output1 := must1(ArgMinMaxOp(operand1, 0, I32)) require.True(t, expected1.Equal(output1), "Valid Case 1 Failed: Expected %s, got %s", expected1, output1) // Case 2: 2D tensor, single axis operand2 := MS(F32, 5, 6) expected2 := MS(I8, 5) - output2 := ArgMinMaxOp(operand2, 1, expected2.DType) + output2 := must1(ArgMinMaxOp(operand2, 1, expected2.DType)) require.True(t, expected2.Equal(output2), "Valid Case 2 Failed: Expected %s, got %s", expected2, output2) // Case 3: 3D tensor, multiple axes operand3 := MS(F32, 4, 5, 6) expected3 := MS(U64, 5, 6) - output3 := ArgMinMaxOp(operand3, 0, expected3.DType) + output3 := must1(ArgMinMaxOp(operand3, 0, expected3.DType)) require.True(t, expected3.Equal(output3), "Valid Case 3 Failed: Expected %s, got %s", expected3, output3) // --- Error Cases --- // Error 1: Invalid operand DType require.Panics(t, func() { - ArgMinMaxOp(shapes.Make(dtypes.InvalidDType, 10), 0, I32) + must1(ArgMinMaxOp(shapes.Make(dtypes.InvalidDType, 10), 0, I32)) }, "Error Case 1 Failed: Invalid operand DType") // Error 2: Invalid axis (out of bounds) require.Panics(t, func() { - ArgMinMaxOp(operand1, 1, I32) // operand1 is rank 1, axis 1 invalid + must1(ArgMinMaxOp(operand1, 1, I32)) // operand1 is rank 1, axis 1 invalid }, "Error Case 2 Failed: Invalid axis (out of bounds)") // Error 3: Negative axis require.Panics(t, func() { - ArgMinMaxOp(operand2, -1, I32) + must1(ArgMinMaxOp(operand2, -1, I32)) }, "Error Case 3 Failed: Negative axis") } From ad9ae3f6d30ae37c38bcb64ee2adbf4f9ef6b10a Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: 
Fri, 9 May 2025 08:07:21 +0200 Subject: [PATCH 3/4] Refactoring shapeinference to return errors. --- backends/shapeinference/shapeinference.go | 127 +++++++++++----------- 1 file changed, 63 insertions(+), 64 deletions(-) diff --git a/backends/shapeinference/shapeinference.go b/backends/shapeinference/shapeinference.go index b400e1c2..d82d6f1e 100644 --- a/backends/shapeinference/shapeinference.go +++ b/backends/shapeinference/shapeinference.go @@ -12,7 +12,6 @@ package shapeinference import ( - "github.com/gomlx/exceptions" "github.com/gomlx/gomlx/backends" "github.com/gomlx/gomlx/types" "github.com/gomlx/gomlx/types/shapes" @@ -44,7 +43,7 @@ var ( backends.OpTypeClz, ) - // NumberOperations can take any type of number as input: integer, float or complex. + // NumberOperations can take any type of number as input: integers, floats, or complex numbers. NumberOperations = types.SetWith( backends.OpTypeAdd, backends.OpTypeSub, @@ -172,7 +171,7 @@ var ( // operations that have two operands usually named lhs (left-hand-side) and rhs (right-hand-side), and they are usually // commutative (invariant to order). // -// It may throw (panic) an exception if the data type (shape.DType) is invalid for the operation -- e.g.: non-matching +// It returns an error if the data type (shape.DType) is invalid for the operation -- e.g.: non-matching // dtypes, or LogicalAnd not having booleans (dtype.Bool) as input. func BinaryOp(opType backends.OpType, lhsShape, rhsShape shapes.Shape) (output shapes.Shape, err error) { if !StandardBinaryOperations.Has(opType) { @@ -268,7 +267,7 @@ func ComparisonOp(opType backends.OpType, lhsShape, rhsShape shapes.Shape) (outp } // UnaryOp checks the validity of the data type for StandardUnaryOperations and returns either an error or -// the output shape, which it the same as the operand. +// the output shape, which is the same as the operand. 
func UnaryOp(opType backends.OpType, operand shapes.Shape) (output shapes.Shape, err error) { if !StandardUnaryOperations.Has(opType) { err = errors.Errorf("operation %s is not in the StandardUnaryOperations set, cannot process it with UnaryOp", opType) @@ -313,10 +312,10 @@ func UnaryOp(opType backends.OpType, operand shapes.Shape) (output shapes.Shape, // WhereOp returns the shape resulting from the Where operation. // -// Shape constraints: +// Shape constraints for the operation: // -// 1. onTrue and onFalse must have the exact same shape. -// 2. condition must either be a scalar or match the shape of onTrue and onFalse, except for the DType that +// 1. The onTrue and onFalse must have the exact same shape, or one can be a scalar. +// 2. The condition must either be a scalar or match the shape of onTrue or onFalse, except for the DType that // must be Bool. func WhereOp(condition, onTrue, onFalse shapes.Shape) (output shapes.Shape, err error) { if condition.DType != dtypes.Bool { @@ -451,18 +450,18 @@ func BroadcastInDimOp(operand, outputShape shapes.Shape, broadcastAxes []int) er } // ReduceOp works for the ReduceMax, ReduceMin, ReduceSum and ReduceProduct ops. -func ReduceOp(operand shapes.Shape, axes []int) shapes.Shape { +func ReduceOp(operand shapes.Shape, axes []int) (output shapes.Shape, err error) { if len(axes) == 0 { - return operand + return operand, nil } - output := shapes.Make(operand.DType) + output = shapes.Make(operand.DType) outputRank := operand.Rank() - len(axes) if outputRank > 0 { // Copy over dimensions that will stay. 
output.Dimensions = make([]int, 0, outputRank) for _, axis := range axes { if axis < 0 || axis >= operand.Rank() { - exceptions.Panicf("Reduce operation require each axis to be 0 <= axis < rank, but got invalid axis %d for shape %s", axis, operand) + return shapes.Invalid(), errors.Errorf("Reduce operation require each axis to be 0 <= axis < rank, but got invalid axis %d for shape %s", axis, operand) } } axesSet := types.SetWith(axes...) @@ -472,12 +471,12 @@ func ReduceOp(operand shapes.Shape, axes []int) shapes.Shape { } } } - return output + return } // GatherOp returns the output shape of a Gather operation. func GatherOp(operand, startIndices shapes.Shape, indexVectorAxis int, offsetOutputAxes, collapsedSliceAxes, - startIndexMap, sliceSizes []int, indicesAreSorted bool) shapes.Shape { + startIndexMap, sliceSizes []int, indicesAreSorted bool) (output shapes.Shape, err error) { //fmt.Printf("GatherOp parameters:\n"+ // " operand: %v\n"+ // " startIndices: %v\n"+ @@ -492,55 +491,55 @@ func GatherOp(operand, startIndices shapes.Shape, indexVectorAxis int, offsetOut _ = indicesAreSorted // Not used for shape inference. 
if operand.IsScalar() { - exceptions.Panicf("Gather() requires a non-scalar operand, got %s", operand) + return output, errors.Errorf("Gather() requires a non-scalar operand, got %s", operand) } setCollapsedAxes := types.MakeSet[int]() for _, collapsedSliceAxis := range collapsedSliceAxes { if collapsedSliceAxis < 0 || collapsedSliceAxis >= operand.Rank() { - exceptions.Panicf("collapsed slice axis %d is out of range for operand %s", collapsedSliceAxis, operand) + return output, errors.Errorf("collapsed slice axis %d is out of range for operand %s", collapsedSliceAxis, operand) } if setCollapsedAxes.Has(collapsedSliceAxis) { - exceptions.Panicf("collapsed slice axis %d is defined more than once for operand %s", collapsedSliceAxis, operand) + return output, errors.Errorf("collapsed slice axis %d is defined more than once for operand %s", collapsedSliceAxis, operand) } setCollapsedAxes.Insert(collapsedSliceAxis) } // Check slice sizes. if len(sliceSizes) != operand.Rank() { - exceptions.Panicf("sliceSizes must have one value per operand axes, so it length (%d) must match operand rank (%d)", len(sliceSizes), operand.Rank()) + return output, errors.Errorf("sliceSizes must have one value per operand axes, so it length (%d) must match operand rank (%d)", len(sliceSizes), operand.Rank()) } for axis, sliceSize := range sliceSizes { if sliceSize < 0 { - exceptions.Panicf("sliceSize %d for axis %d is negative, it must be non-negative", sliceSize, axis) + return output, errors.Errorf("sliceSize %d for axis %d is negative, it must be non-negative", sliceSize, axis) } if operand.Dimensions[axis] < sliceSize { - exceptions.Panicf("sliceSize %d for axis %d is larger than the corresponding operand dimension %d", sliceSize, axis, operand.Dimensions[axis]) + return output, errors.Errorf("sliceSize %d for axis %d is larger than the corresponding operand dimension %d", sliceSize, axis, operand.Dimensions[axis]) } } for collapseAxis := range setCollapsedAxes { if 
sliceSizes[collapseAxis] != 1 { - exceptions.Panicf("collapsed slice axis %d must have sliceSize 1, but got %d", collapseAxis, sliceSizes[collapseAxis]) + return output, errors.Errorf("collapsed slice axis %d must have sliceSize 1, but got %d", collapseAxis, sliceSizes[collapseAxis]) } } if operand.Rank() != len(collapsedSliceAxes)+len(offsetOutputAxes) { - exceptions.Panicf("the number of collapsedSliceAxes (%d) + the number of offsetOutputAxes (%d) must be equal to the number of axes in the operand (operand.Rank()=%d)", + return output, errors.Errorf("the number of collapsedSliceAxes (%d) + the number of offsetOutputAxes (%d) must be equal to the number of axes in the operand (operand.Rank()=%d)", len(collapsedSliceAxes), len(offsetOutputAxes), operand.Rank()) } // Check indexVectorAxis: it is ok if it is equal to startIndices.rank, in which case we assume implicit extra axes of dimension 1. if indexVectorAxis < 0 || indexVectorAxis > operand.Rank() { - exceptions.Panicf("indexVectorAxis=%d is out of range for operand %s", indexVectorAxis, operand) + return output, errors.Errorf("indexVectorAxis=%d is out of range for operand %s", indexVectorAxis, operand) } // Check startIndexMap is set for the dimensions of indexVectorAxis in startIndices. 
if len(startIndexMap) != startIndices.Dimensions[indexVectorAxis] { - exceptions.Panicf("startIndexMap must have one value per dimension of indexVectorAxis, so it length (%d) must match startIndices.Dimensions[%d] (%d)", + return output, errors.Errorf("startIndexMap must have one value per dimension of indexVectorAxis, so it length (%d) must match startIndices.Dimensions[%d] (%d)", len(startIndexMap), indexVectorAxis, startIndices.Dimensions[indexVectorAxis]) } for idx, operandAxis := range startIndexMap { if operandAxis < 0 || operandAxis >= operand.Rank() { - exceptions.Panicf("startIndexMap[%d]=%d is out of range for operand %s", idx, operandAxis, operand) + return output, errors.Errorf("startIndexMap[%d]=%d is out of range for operand %s", idx, operandAxis, operand) } } @@ -555,16 +554,16 @@ func GatherOp(operand, startIndices shapes.Shape, indexVectorAxis int, offsetOut // // - Axes in offsetOutputAxes are preset as offset, and their dimensions are taken sequentially from non-collapsed operand axes. // - Remaining axes are filled in order from the batch axes, taken from startIndices. 
- output := shapes.Make(operand.DType) + output = shapes.Make(operand.DType) output.Dimensions = make([]int, batchRank+len(offsetOutputAxes)) setOffsetOutputAxes := types.MakeSet[int]() for _, offsetOutputAxis := range offsetOutputAxes { if offsetOutputAxis < 0 || offsetOutputAxis >= output.Rank() { - exceptions.Panicf("offset output axis %d is out of range for output of rank %d", offsetOutputAxis, output.Rank()) + return shapes.Invalid(), errors.Errorf("offset output axis %d is out of range for output of rank %d", offsetOutputAxis, output.Rank()) } if setOffsetOutputAxes.Has(offsetOutputAxis) { - exceptions.Panicf("offset output axis %d is defined more than once: offsetOutputAxes=%v", offsetOutputAxis, offsetOutputAxes) + return shapes.Invalid(), errors.Errorf("offset output axis %d is defined more than once: offsetOutputAxes=%v", offsetOutputAxis, offsetOutputAxes) } setOffsetOutputAxes.Insert(offsetOutputAxis) } @@ -592,44 +591,44 @@ func GatherOp(operand, startIndices shapes.Shape, indexVectorAxis int, offsetOut batchDimsIdx++ } } - return output + return output, nil } // ConcatenateOp calculates the output shape of a Concatenate operation. // It takes a slice of input shapes and the dimension along which to concatenate. -func ConcatenateOp(inputs []shapes.Shape, axis int) shapes.Shape { +func ConcatenateOp(inputs []shapes.Shape, axis int) (output shapes.Shape, err error) { if len(inputs) == 0 { - exceptions.Panicf("ConcatenateOp requires at least one input shape") + return shapes.Invalid(), errors.Errorf("ConcatenateOp requires at least one input shape") } // Initialize output dimensions with the first shape. 
firstShape := inputs[0] dtype := firstShape.DType rank := firstShape.Rank() - output := firstShape.Clone() + output = firstShape.Clone() if dtype == dtypes.InvalidDType { - exceptions.Panicf("invalid shape %s for first input of ConcatenateOp", firstShape) + return shapes.Invalid(), errors.Errorf("invalid shape %s for first input of ConcatenateOp", firstShape) } if len(inputs) == 1 { - return firstShape + return firstShape, nil } if axis < 0 || axis >= rank { - exceptions.Panicf("invalid concatenation axis %d for shapes with rank %d", axis, rank) + return shapes.Invalid(), errors.Errorf("invalid concatenation axis %d for shapes with rank %d", axis, rank) } // Validate further inputs and accumulate the concatenation axis size. for i := 1; i < len(inputs); i++ { currentShape := inputs[i] if currentShape.DType == dtypes.InvalidDType { - exceptions.Panicf("invalid shape %s for input #%d of ConcatenateOp", currentShape, i) + return shapes.Invalid(), errors.Errorf("invalid shape %s for input #%d of ConcatenateOp", currentShape, i) } if currentShape.DType != dtype { - exceptions.Panicf("mismatched DTypes for ConcatenateOp: input #0 has %s, input #%d has %s", + return shapes.Invalid(), errors.Errorf("mismatched DTypes for ConcatenateOp: input #0 has %s, input #%d has %s", dtype, i, currentShape.DType) } if currentShape.Rank() != rank { - exceptions.Panicf("mismatched ranks for ConcatenateOp: input #0 has rank %d, input #%d has rank %d", + return shapes.Invalid(), errors.Errorf("mismatched ranks for ConcatenateOp: input #0 has rank %d, input #%d has rank %d", rank, i, currentShape.Rank()) } @@ -638,33 +637,33 @@ func ConcatenateOp(inputs []shapes.Shape, axis int) shapes.Shape { output.Dimensions[d] += currentShape.Dimensions[d] } else { if currentShape.Dimensions[d] != output.Dimensions[d] { - exceptions.Panicf("mismatched dimensions for ConcatenateOp at axis %d (non-concatenation axis): input #0 has %d, input #%d has %d", + return shapes.Invalid(), errors.Errorf("mismatched 
dimensions for ConcatenateOp at axis %d (non-concatenation axis): input #0 has %d, input #%d has %d", d, output.Dimensions[d], i, currentShape.Dimensions[d]) } } } } - return output + return output, nil } // ScatterOp checks that the parameters are consistent. The output shape returned is the unchanged operand -- the scattered // updates are applied to the operand, but its shape is unchanged. // // The Scatter operations indicesAreSorted and uniqueIndices don't play a role in this. -func ScatterOp(operand, indices, updates shapes.Shape, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int) shapes.Shape { +func ScatterOp(operand, indices, updates shapes.Shape, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int) (output shapes.Shape, err error) { if operand.DType == dtypes.InvalidDType || indices.DType == dtypes.InvalidDType || updates.DType == dtypes.InvalidDType { - exceptions.Panicf("invalid shape for operand (%s), indices (%s) or updates (%s) for ScatterOp", operand, indices, updates) + return shapes.Invalid(), errors.Errorf("invalid shape for operand (%s), indices (%s) or updates (%s) for ScatterOp", operand, indices, updates) } if operand.DType != updates.DType { - exceptions.Panicf("data types (DType) for ScatterOp operand (%s) and updates (%s) must match", operand, updates) + return shapes.Invalid(), errors.Errorf("data types (DType) for ScatterOp operand (%s) and updates (%s) must match", operand, updates) } if !indices.DType.IsInt() { - exceptions.Panicf("indices DType (%s) must be an integer type", indices) + return shapes.Invalid(), errors.Errorf("indices DType (%s) must be an integer type", indices) } // Check indexVectorAxis and get scatter indices dimensions. 
if indexVectorAxis < 0 || indexVectorAxis > indices.Rank() { - exceptions.Panicf("indexVectorAxis=%d must be in range [0, indices.Rank()=%d]", indexVectorAxis, indices.Rank()) + return shapes.Invalid(), errors.Errorf("indexVectorAxis=%d must be in range [0, indices.Rank()=%d]", indexVectorAxis, indices.Rank()) } // Validate scatter axes mapping. @@ -673,17 +672,17 @@ func ScatterOp(operand, indices, updates shapes.Shape, indexVectorAxis int, upda numIndexedAxes = indices.Dimensions[indexVectorAxis] } if len(scatterAxesToOperandAxes) != numIndexedAxes { - exceptions.Panicf("scatterAxesToOperandAxes length (%d) must match the size of indices's indexVectorAxis dimension (%d)", + return shapes.Invalid(), errors.Errorf("scatterAxesToOperandAxes length (%d) must match the size of indices's indexVectorAxis dimension (%d)", len(scatterAxesToOperandAxes), indices.Dimensions[indexVectorAxis]) } for i, axis := range scatterAxesToOperandAxes { if axis < 0 || axis >= operand.Rank() { - exceptions.Panicf("scatterAxesToOperandAxes[%d]=%d must be in range [0, operand.Rank()=%d)", i, axis, operand.Rank()) + return shapes.Invalid(), errors.Errorf("scatterAxesToOperandAxes[%d]=%d must be in range [0, operand.Rank()=%d)", i, axis, operand.Rank()) } } for i, axis := range updateWindowAxes { if axis < 0 || axis >= updates.Rank() { - exceptions.Panicf("updateWindowAxes[%d]=%d must be in range [0, updates.Rank()=%d)", i, axis, updates.Rank()-1) + return shapes.Invalid(), errors.Errorf("updateWindowAxes[%d]=%d must be in range [0, updates.Rank()=%d)", i, axis, updates.Rank()-1) } } @@ -693,19 +692,19 @@ func ScatterOp(operand, indices, updates shapes.Shape, indexVectorAxis int, upda numBatchAxes++ } if len(updateWindowAxes)+numBatchAxes != updates.Rank() { - exceptions.Panicf("numBatchAxes (%d) + len(updateWindowAxes) (%d) must match updates.Rank() (%d), so it "+ + return shapes.Invalid(), errors.Errorf("numBatchAxes (%d) + len(updateWindowAxes) (%d) must match updates.Rank() (%d), so it 
"+ "can fully addressed -- where numBatchAxes=indices.Rank() - 1, or if indexVector == indices.Rank(), numBatchAxes=indices.Rank()", numBatchAxes, len(updateWindowAxes), updates.Rank()) } // Validate update window dimensions. if len(updateWindowAxes)+len(insertedWindowAxes) != operand.Rank() { - exceptions.Panicf("operand.Rank() (%d) must match len(updateWindowAxes)(%d)+len(insertedWindowAxes)(%d), so operand indices can be fully defined", + return shapes.Invalid(), errors.Errorf("operand.Rank() (%d) must match len(updateWindowAxes)(%d)+len(insertedWindowAxes)(%d), so operand indices can be fully defined", operand.Rank(), len(updateWindowAxes), len(insertedWindowAxes)) } for i, axis := range insertedWindowAxes { if axis < 0 || axis >= operand.Rank() { - exceptions.Panicf("insertedWindowAxes[%d]=%d must be in range [0, operand.Rank()=%d)", i, axis, operand.Rank()) + return shapes.Invalid(), errors.Errorf("insertedWindowAxes[%d]=%d must be in range [0, operand.Rank()=%d)", i, axis, operand.Rank()) } } @@ -720,34 +719,34 @@ func ScatterOp(operand, indices, updates shapes.Shape, indexVectorAxis int, upda for ii, updatesAxis := range updateWindowAxes { operandAxis := operandUpdatedWindowAxes[ii] if updates.Dimensions[updatesAxis] > operand.Dimensions[operandAxis] { - exceptions.Panicf("updates.Dimensions[axis=%d](%d) > operand.Dimensions[axis=%d](%d), updates won't fit into the operand", + return shapes.Invalid(), errors.Errorf("updates.Dimensions[axis=%d](%d) > operand.Dimensions[axis=%d](%d), updates won't fit into the operand", updatesAxis, updates.Dimensions[updatesAxis], operandAxis, operand.Dimensions[operandAxis]) } } - return operand + return operand, nil } // SliceOp calculates the output shape for a Slice operation. // It checks that starts, limits, and strides have the correct length (matching operand rank), // and that the slice parameters are valid for the operand's dimensions. // Strides must be positive. 
-func SliceOp(operand shapes.Shape, starts, limits, strides []int) shapes.Shape { +func SliceOp(operand shapes.Shape, starts, limits, strides []int) (output shapes.Shape, err error) { rank := operand.Rank() opName := "SliceOp" if operand.DType == dtypes.InvalidDType { - exceptions.Panicf("%s: invalid operand shape %s", opName, operand) + return shapes.Invalid(), errors.Errorf("%s: invalid operand shape %s", opName, operand) } if len(starts) != rank { - exceptions.Panicf("%s: len(starts)=%d, but operand rank is %d", opName, len(starts), rank) + return shapes.Invalid(), errors.Errorf("%s: len(starts)=%d, but operand rank is %d", opName, len(starts), rank) } if len(limits) != rank { - exceptions.Panicf("%s: len(limits)=%d, but operand rank is %d", opName, len(limits), rank) + return shapes.Invalid(), errors.Errorf("%s: len(limits)=%d, but operand rank is %d", opName, len(limits), rank) } if len(strides) != rank { - exceptions.Panicf("%s: len(strides)=%d, but operand rank is %d", opName, len(strides), rank) + return shapes.Invalid(), errors.Errorf("%s: len(strides)=%d, but operand rank is %d", opName, len(strides), rank) } - outputShape := shapes.Shape{ + output = shapes.Shape{ DType: operand.DType, Dimensions: make([]int, rank), } @@ -757,29 +756,29 @@ func SliceOp(operand shapes.Shape, starts, limits, strides []int) shapes.Shape { dimSize := operand.Dimensions[axis] if stride <= 0 { - exceptions.Panicf("%s: stride must be positive, but got stride[%d]=%d for operand shape %s", + return shapes.Invalid(), errors.Errorf("%s: stride must be positive, but got stride[%d]=%d for operand shape %s", opName, axis, stride, operand) } if start < 0 || start >= dimSize { - exceptions.Panicf("%s: start index %d is out of bounds for axis %d with size %d (operand shape %s)", + return shapes.Invalid(), errors.Errorf("%s: start index %d is out of bounds for axis %d with size %d (operand shape %s)", opName, start, axis, dimSize, operand) } // Limit can be equal to dimSize. 
if limit < start || limit > dimSize { - exceptions.Panicf("%s: limit index %d is out of bounds for axis %d (start=%d, size=%d, operand shape %s)", + return shapes.Invalid(), errors.Errorf("%s: limit index %d is out of bounds for axis %d (start=%d, size=%d, operand shape %s)", opName, limit, axis, start, dimSize, operand) } // The first one is always taken, so we use the ceiling of the division. outputDimSize := (limit - start + (stride - 1)) / stride - outputShape.Dimensions[axis] = outputDimSize + output.Dimensions[axis] = outputDimSize } - return outputShape + return output, nil } // ArgMinMaxOp calculates the output shape for an ArgMinMax operation. -// It will be the shape of the operand minus the reduce axis. +// It will be the shape of the operand minus the "reduce" axis. func ArgMinMaxOp(operand shapes.Shape, axis int, outputDType dtypes.DType) (output shapes.Shape, err error) { if !outputDType.IsInt() { err = errors.Errorf("ArgMinMax outputDType must be an integer type, got %s", outputDType) From ddea99ce077a27de0a877fb588b4e1823a4cc4ef Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Fri, 9 May 2025 08:07:41 +0200 Subject: [PATCH 4/4] Refactoring shapeinference to return errors. 
--- .../shapeinference/shapeinference_test.go | 143 +++++++++--------- 1 file changed, 69 insertions(+), 74 deletions(-) diff --git a/backends/shapeinference/shapeinference_test.go b/backends/shapeinference/shapeinference_test.go index 1cc7aaf7..c7f078dd 100644 --- a/backends/shapeinference/shapeinference_test.go +++ b/backends/shapeinference/shapeinference_test.go @@ -110,8 +110,9 @@ func TestGatherOp(t *testing.T) { collapsedSliceAxes := []int{0, 2} startIndexMap := []int{0, 2, 3} sliceSizes := []int{1, 3, 1, 1} - output := GatherOp(operand, startIndices, indexVectorAxis, + output, err := GatherOp(operand, startIndices, indexVectorAxis, offsetOutputAxes, collapsedSliceAxes, startIndexMap, sliceSizes, false) + require.NoError(t, err) fmt.Printf("\tTest 1: outputShape=%s\n", output) require.NoError(t, output.Check(F32, 3, 3, 2, 1)) @@ -123,7 +124,8 @@ func TestGatherOp(t *testing.T) { collapsedSliceAxes = []int{1, 2} startIndexMap = []int{1, 2, 3} sliceSizes = []int{3, 1, 1, 1} - output = GatherOp(operand, startIndices, indexVectorAxis, offsetOutputAxes, collapsedSliceAxes, startIndexMap, sliceSizes, true) + output, err = GatherOp(operand, startIndices, indexVectorAxis, offsetOutputAxes, collapsedSliceAxes, startIndexMap, sliceSizes, true) + require.NoError(t, err) fmt.Printf("\tTest 2: outputShape=%s\n", output) require.NoError(t, output.Check(F32, 7, 3, 1, 8)) @@ -135,7 +137,8 @@ func TestGatherOp(t *testing.T) { collapsedSliceAxes = []int{0} startIndexMap = []int{0} sliceSizes = []int{1, 16} - output = GatherOp(operand, startIndices, indexVectorAxis, offsetOutputAxes, collapsedSliceAxes, startIndexMap, sliceSizes, true) + output, err = GatherOp(operand, startIndices, indexVectorAxis, offsetOutputAxes, collapsedSliceAxes, startIndexMap, sliceSizes, true) + require.NoError(t, err) fmt.Printf("\tTest 3: outputShape=%s\n", output) require.NoError(t, output.Check(F32, 8, 16)) } @@ -154,7 +157,8 @@ func TestScatterOp(t *testing.T) { insertedWindowAxes1 := []int{0} 
scatterAxesToOperandAxes1 := []int{0} // Index coordinate vector element 0 maps to operand axis 0 expected1 := operand1 - output1 := ScatterOp(operand1, indices1, updates1, indexVectorAxis1, updateWindowAxes1, insertedWindowAxes1, scatterAxesToOperandAxes1) + output1, err := ScatterOp(operand1, indices1, updates1, indexVectorAxis1, updateWindowAxes1, insertedWindowAxes1, scatterAxesToOperandAxes1) + require.NoError(t, err) require.True(t, expected1.Equal(output1), "Valid Case 1 Failed: Expected %s, got %s", expected1, output1) // Case 2: Scattering into a higher-rank tensor @@ -170,7 +174,8 @@ func TestScatterOp(t *testing.T) { insertedWindowAxes2 := []int{0, 1} // Axis 0, 1 of operand are the dimensions determined by the indices[i,j,:] scatterAxesToOperandAxes2 := []int{0, 1} // index coord 0 -> operand axis 0, index coord 1 -> operand axis 1 expected2 := operand2 - output2 := ScatterOp(operand2, indices2, updates2, indexVectorAxis2, updateWindowAxes2, insertedWindowAxes2, scatterAxesToOperandAxes2) + output2, err := ScatterOp(operand2, indices2, updates2, indexVectorAxis2, updateWindowAxes2, insertedWindowAxes2, scatterAxesToOperandAxes2) + require.NoError(t, err) require.True(t, expected2.Equal(output2), "Valid Case 2 Failed: Expected %s, got %s", expected2, output2) // Case 3: Different indexVectorAxis @@ -183,7 +188,8 @@ func TestScatterOp(t *testing.T) { insertedWindowAxes3 := []int{1, 2} scatterAxesToOperandAxes3 := []int{1, 2} // indices are used for different axes in the operand this time. 
expected3 := operand2 // Still expect operand shape - output3 := ScatterOp(operand3, indices3, updates3, indexVectorAxis3, updateWindowAxes3, insertedWindowAxes3, scatterAxesToOperandAxes3) + output3, err := ScatterOp(operand3, indices3, updates3, indexVectorAxis3, updateWindowAxes3, insertedWindowAxes3, scatterAxesToOperandAxes3) + require.NoError(t, err) require.True(t, expected3.Equal(output3), "Valid Case 3 Failed (IndexVecAxis=1): Expected %s, got %s", expected3, output3) // Case 4: No insertedWindowAxes (scattering full slices) @@ -196,7 +202,8 @@ func TestScatterOp(t *testing.T) { insertedWindowAxes4 := []int{0} // No window axes in operand (index selects full slice - which is scalar here) scatterAxesToOperandAxes4 := []int{0} // Index coord 0 -> operand axis 0 expected4 := operand4 - output4 := ScatterOp(operand4, indices4, updates4, indexVectorAxis4, updateWindowAxes4, insertedWindowAxes4, scatterAxesToOperandAxes4) + output4, err := ScatterOp(operand4, indices4, updates4, indexVectorAxis4, updateWindowAxes4, insertedWindowAxes4, scatterAxesToOperandAxes4) + require.NoError(t, err) require.True(t, expected4.Equal(output4), "Valid Case 4 Failed (No Window): Expected %s, got %s", expected4, output4) // Case 5: rearranging the output axes: @@ -207,68 +214,52 @@ func TestScatterOp(t *testing.T) { updateWindowAxes5 := []int{0} insertedWindowAxes5 := []int{0, 2} scatterAxesToOperandAxes5 := []int{0, 2} - output5 := ScatterOp(operand5, indices5, updates5, indexVectorAxis5, updateWindowAxes5, insertedWindowAxes5, scatterAxesToOperandAxes5) + output5, err := ScatterOp(operand5, indices5, updates5, indexVectorAxis5, updateWindowAxes5, insertedWindowAxes5, scatterAxesToOperandAxes5) + require.NoError(t, err) require.True(t, operand5.Equal(output5), "Valid Case 5 Failed (No Window): Expected %s, got %s", operand5, output5) // --- Error Cases --- // Error Case 1: Mismatched DType (Operand vs Updates) - unchanged - require.Panics(t, func() { - ScatterOp(MS(F32, 4, 5), 
MS(I8, 2, 1), MS(I8, 2, 5), 1, []int{1}, []int{1}, []int{0}) - }, "Error Case 1 Failed: Mismatched operand/updates DType") + _, err = ScatterOp(MS(F32, 4, 5), MS(I8, 2, 1), MS(I8, 2, 5), 1, []int{1}, []int{1}, []int{0}) + require.Error(t, err, "Error Case 1 Failed: Mismatched operand/updates DType") // Error Case 2: Invalid DType for indices - unchanged - require.Panics(t, func() { - ScatterOp(MS(F32, 4, 5), MS(F32, 2, 1), MS(F32, 2, 5), 1, []int{1}, []int{1}, []int{0}) - }, "Error Case 2 Failed: Invalid indices DType") + _, err = ScatterOp(MS(F32, 4, 5), MS(F32, 2, 1), MS(F32, 2, 5), 1, []int{1}, []int{1}, []int{0}) + require.Error(t, err, "Error Case 2 Failed: Invalid indices DType") // Error Case 3: indexVectorAxis out of bounds - require.Panics(t, func() { - ScatterOp(operand1, indices1, updates1, 2, updateWindowAxes1, insertedWindowAxes1, scatterAxesToOperandAxes1) // indices1 rank 2, axis 2 invalid - }, "Error Case 3 Failed: indexVectorAxis out of bounds") - require.Panics(t, func() { - ScatterOp(operand1, indices1, updates1, -1, updateWindowAxes1, insertedWindowAxes1, scatterAxesToOperandAxes1) // Negative axis - }, "Error Case 3 Failed: negative indexVectorAxis") + _, err = ScatterOp(operand1, indices1, updates1, 2, updateWindowAxes1, insertedWindowAxes1, scatterAxesToOperandAxes1) // indices1 rank 2, axis 2 invalid + require.Error(t, err, "Error Case 3 Failed: indexVectorAxis out of bounds") + + _, err = ScatterOp(operand1, indices1, updates1, -1, updateWindowAxes1, insertedWindowAxes1, scatterAxesToOperandAxes1) // Negative axis + require.Error(t, err, "Error Case 3 Failed: negative indexVectorAxis") // Error Case 4: len(updateWindowAxes) != len(insertedWindowAxes) - require.Panics(t, func() { - ScatterOp(operand1, indices1, updates1, indexVectorAxis1, []int{1}, []int{1, 0}, scatterAxesToOperandAxes1) // inserted has len 2, update has len 1 - }, "Error Case 4 Failed: len(updateWindowAxes) != len(insertedWindowAxes)") + _, err = ScatterOp(operand1, 
indices1, updates1, indexVectorAxis1, []int{1}, []int{1, 0}, scatterAxesToOperandAxes1) // inserted has len 2, update has len 1 + require.Error(t, err, "Error Case 4 Failed: len(updateWindowAxes) != len(insertedWindowAxes)") // Error Case 5: len(scatterAxesToOperandAxes) != size of index vector dimension - require.Panics(t, func() { - // indices1 shape [2, 1], indexVectorAxis1=1 -> index vector size is 1 - ScatterOp(operand1, indices1, updates1, indexVectorAxis1, updateWindowAxes1, insertedWindowAxes1, []int{0, 1}) // scatterAxes has len 2, expected 1 - }, "Error Case 5 Failed: len(scatterAxesToOperandAxes) mismatch") - require.Panics(t, func() { - // indices2 shape [2, 3, 2], indexVectorAxis2=2 -> index vector size is 2 - ScatterOp(operand2, indices2, updates2, indexVectorAxis2, updateWindowAxes2, insertedWindowAxes2, []int{0}) // scatterAxes has len 1, expected 2 - }, "Error Case 5 Failed: len(scatterAxesToOperandAxes) mismatch") + _, err = ScatterOp(operand1, indices1, updates1, indexVectorAxis1, updateWindowAxes1, insertedWindowAxes1, []int{0, 1}) // scatterAxes has len 2, expected 1 + require.Error(t, err, "Error Case 5 Failed: len(scatterAxesToOperandAxes) mismatch") + _, err = ScatterOp(operand2, indices2, updates2, indexVectorAxis2, updateWindowAxes2, insertedWindowAxes2, []int{0}) // scatterAxes has len 1, expected 2 + require.Error(t, err, "Error Case 5 Failed: len(scatterAxesToOperandAxes) mismatch") // Error Case 6: Invalid axis index in updateWindowAxes - require.Panics(t, func() { - // updates1 shape [2, 5], rank 2 - ScatterOp(operand1, indices1, updates1, indexVectorAxis1, []int{2}, insertedWindowAxes1, scatterAxesToOperandAxes1) // axis 2 invalid for rank 2 updates - }, "Error Case 6 Failed: Invalid axis in updateWindowAxes") + _, err = ScatterOp(operand1, indices1, updates1, indexVectorAxis1, []int{2}, insertedWindowAxes1, scatterAxesToOperandAxes1) // axis 2 invalid for rank 2 updates + require.Error(t, err, "Error Case 6 Failed: Invalid axis in 
updateWindowAxes") // Error Case 7: Invalid axis index in insertedWindowAxes - require.Panics(t, func() { - // operand1 shape [4, 5], rank 2 - ScatterOp(operand1, indices1, updates1, indexVectorAxis1, updateWindowAxes1, []int{2}, scatterAxesToOperandAxes1) // axis 2 invalid for rank 2 operand - }, "Error Case 7 Failed: Invalid axis in insertedWindowAxes") + _, err = ScatterOp(operand1, indices1, updates1, indexVectorAxis1, updateWindowAxes1, []int{2}, scatterAxesToOperandAxes1) // axis 2 invalid for rank 2 operand + require.Error(t, err, "Error Case 7 Failed: Invalid axis in insertedWindowAxes") // Error Case 8: Invalid axis index in scatterAxesToOperandAxes - require.Panics(t, func() { - // operand1 shape [4, 5], rank 2 - ScatterOp(operand1, indices1, updates1, indexVectorAxis1, updateWindowAxes1, insertedWindowAxes1, []int{2}) // axis 2 invalid for rank 2 operand - }, "Error Case 8 Failed: Invalid axis in scatterAxesToOperandAxes") + _, err = ScatterOp(operand1, indices1, updates1, indexVectorAxis1, updateWindowAxes1, insertedWindowAxes1, []int{2}) // axis 2 invalid for rank 2 operand + require.Error(t, err, "Error Case 8 Failed: Invalid axis in scatterAxesToOperandAxes") // Error Case 9: Update dimension is larger than the corresponding dimension in the operand: - require.Panics(t, func() { - // operand1 shape [4, 5], rank 2 - ScatterOp(operand5, indices5, updates5, indexVectorAxis5, updateWindowAxes5, []int{0, 1}, scatterAxesToOperandAxes5) // axis 2 invalid for rank 2 operand - }, "Error Case 9 Failed: Update dimension is larger than the corresponding dimension in the operand") - + _, err = ScatterOp(operand5, indices5, updates5, indexVectorAxis5, updateWindowAxes5, []int{0, 1}, scatterAxesToOperandAxes5) // axis 2 invalid for rank 2 operand + require.Error(t, err, "Error Case 9 Failed: Update dimension is larger than the corresponding dimension in the operand") } func TestSliceOp(t *testing.T) { @@ -281,7 +272,8 @@ func TestSliceOp(t *testing.T) { limits1 := 
[]int{8} strides1 := []int{1} expected1 := MS(F32, 6) - output1 := SliceOp(operand1, starts1, limits1, strides1) + output1, err := SliceOp(operand1, starts1, limits1, strides1) + require.NoError(t, err) require.True(t, expected1.Equal(output1), "%s Valid Case 1 Failed: Expected %s, got %s", opName, expected1, output1) // Case 2: 2D slice with stride 1 @@ -290,7 +282,8 @@ func TestSliceOp(t *testing.T) { limits2 := []int{4, 5} strides2 := []int{1, 1} expected2 := MS(I32, 3, 3) - output2 := SliceOp(operand2, starts2, limits2, strides2) + output2, err := SliceOp(operand2, starts2, limits2, strides2) + require.NoError(t, err) require.True(t, expected2.Equal(output2), "%s Valid Case 2 Failed: Expected %s, got %s", opName, expected2, output2) // Case 3: 3D slice with different strides @@ -302,7 +295,8 @@ func TestSliceOp(t *testing.T) { // Dim 1: (8-0)/3 = 2.66 -> 3 elements (indices 0, 3, 6) // Dim 2: (6-1)/1 = 5 -> 5 elements (indices 1, 2, 3, 4, 5) expected3 := MS(Bool, 5, 3, 5) - output3 := SliceOp(operand3, starts3, limits3, strides3) + output3, err := SliceOp(operand3, starts3, limits3, strides3) + require.NoError(t, err) require.True(t, expected3.Equal(output3), "%s Valid Case 3 Failed: Expected %s, got %s", opName, expected3, output3) // Case 4: Slice resulting in size 1 dimension @@ -311,7 +305,8 @@ func TestSliceOp(t *testing.T) { limits4 := []int{6} strides4 := []int{1} expected4 := MS(F32, 1) - output4 := SliceOp(operand4, starts4, limits4, strides4) + output4, err := SliceOp(operand4, starts4, limits4, strides4) + require.NoError(t, err) require.True(t, expected4.Equal(output4), "%s Valid Case 4 Failed: Expected %s, got %s", opName, expected4, output4) // Case 5: Slice taking full dimension with stride > 1 @@ -321,7 +316,8 @@ func TestSliceOp(t *testing.T) { strides5 := []int{2} // Dim 0: (7-0)/2 = 3.5 -> 4 elements (indices 0, 2, 4, 6) expected5 := MS(I8, 4) - output5 := SliceOp(operand5, starts5, limits5, strides5) + output5, err := SliceOp(operand5, 
starts5, limits5, strides5) + require.NoError(t, err) require.True(t, expected5.Equal(output5), "%s Valid Case 5 Failed: Expected %s, got %s", opName, expected5, output5) // --- Error Cases --- @@ -331,45 +327,44 @@ func TestSliceOp(t *testing.T) { validStrides := []int{1, 1} // Error 1: Invalid operand DType - require.Panics(t, func() { - SliceOp(shapes.Shape{DType: dtypes.InvalidDType, Dimensions: []int{10}}, []int{0}, []int{5}, []int{1}) - }, "%s Error Case 1 Failed: Invalid operand DType", opName) + _, err = SliceOp(shapes.Shape{DType: dtypes.InvalidDType, Dimensions: []int{10}}, []int{0}, []int{5}, []int{1}) + require.Error(t, err, "%s Error Case 1 Failed: Invalid operand DType", opName) // Error 2: Incorrect length for starts - require.Panics(t, func() { SliceOp(operand, []int{1}, validLimits, validStrides) }, - "%s Error Case 2 Failed: len(starts) != rank", opName) + _, err = SliceOp(operand, []int{1}, validLimits, validStrides) + require.Error(t, err, "%s Error Case 2 Failed: len(starts) != rank", opName) // Error 3: Incorrect length for limits - require.Panics(t, func() { SliceOp(operand, validStarts, []int{8}, validStrides) }, - "%s Error Case 3 Failed: len(limits) != rank", opName) + _, err = SliceOp(operand, validStarts, []int{8}, validStrides) + require.Error(t, err, "%s Error Case 3 Failed: len(limits) != rank", opName) // Error 4: Incorrect length for strides - require.Panics(t, func() { SliceOp(operand, validStarts, validLimits, []int{1}) }, - "%s Error Case 4 Failed: len(strides) != rank", opName) + _, err = SliceOp(operand, validStarts, validLimits, []int{1}) + require.Error(t, err, "%s Error Case 4 Failed: len(strides) != rank", opName) // Error 5: Zero stride - require.Panics(t, func() { SliceOp(operand, validStarts, validLimits, []int{1, 0}) }, - "%s Error Case 5 Failed: Zero stride", opName) + _, err = SliceOp(operand, validStarts, validLimits, []int{1, 0}) + require.Error(t, err, "%s Error Case 5 Failed: Zero stride", opName) // Error 6: 
Negative stride - require.Panics(t, func() { SliceOp(operand, validStarts, validLimits, []int{-1, 1}) }, - "%s Error Case 6 Failed: Negative stride", opName) + _, err = SliceOp(operand, validStarts, validLimits, []int{-1, 1}) + require.Error(t, err, "%s Error Case 6 Failed: Negative stride", opName) // Error 7: Start index < 0 - require.Panics(t, func() { SliceOp(operand, []int{-1, 1}, validLimits, validStrides) }, - "%s Error Case 7 Failed: Start < 0", opName) + _, err = SliceOp(operand, []int{-1, 1}, validLimits, validStrides) + require.Error(t, err, "%s Error Case 7 Failed: Start < 0", opName) // Error 8: Start index >= dimSize - require.Panics(t, func() { SliceOp(operand, []int{10, 1}, validLimits, validStrides) }, - "%s Error Case 8 Failed: Start >= dimSize", opName) + _, err = SliceOp(operand, []int{10, 1}, validLimits, validStrides) + require.Error(t, err, "%s Error Case 8 Failed: Start >= dimSize", opName) // Error 9: Limit index < start index - require.Panics(t, func() { SliceOp(operand, validStarts, []int{0, 4}, validStrides) }, // limit[0]=0 < start[0]=1 - "%s Error Case 9 Failed: Limit < Start", opName) + _, err = SliceOp(operand, validStarts, []int{0, 4}, validStrides) // limit[0]=0 < start[0]=1 + require.Error(t, err, "%s Error Case 9 Failed: Limit < Start", opName) // Error 10: Limit index > dimSize - require.Panics(t, func() { SliceOp(operand, validStarts, []int{8, 6}, validStrides) }, // limit[1]=6 > dimSize[1]=5 - "%s Error Case 10 Failed: Limit > dimSize", opName) + _, err = SliceOp(operand, validStarts, []int{8, 6}, validStrides) // limit[1]=6 > dimSize[1]=5 + require.Error(t, err, "%s Error Case 10 Failed: Limit > dimSize", opName) } func TestArgMinMaxOp(t *testing.T) {