Remove spurious symbol on lit error checks (#2417)
Our repo uses `expected-error@+1` rather than `@expected-error@+1` for the majority of its tests, so let's consolidate on one style.
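
For reference, the style being consolidated to looks like this in a lit test. The snippet below is a minimal sketch assembled from one of the checks updated in this commit; the `RUN` line is an assumption added for context and may not match the file's actual invocation:

```mlir
// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file

// The diagnostic verifier matches `expected-error@+N` (no leading `@`)
// against the error emitted N lines below the comment.
func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> {
  // expected-error@+2 {{output shape [-1] is incompatible with return type of operation 'tensor<4xf32>'}}
  %0 = stablehlo.constant dense<[-1]> : tensor<1xi64>
  %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
  func.return %1 : tensor<4xf32>
}
```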
ghpvnist committed Jun 25, 2024
1 parent 59b7fd2 commit aad3d9d
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions stablehlo/tests/ops_stablehlo.mlir
@@ -1114,7 +1114,7 @@ func.func @dynamic_broadcast_in_dim_c5_input_mismatch_with_shape(%arg0: tensor<1
// -----

func.func @dynamic_broadcast_in_dim_c7_output_dimensions_negative_size(%arg0: tensor<4xf32>) -> tensor<3x4xf32> {
-// @expected-error@+2 {{output shape [-1, 4] is incompatible with return type of operation 'tensor<3x4xf32>'}}
+// expected-error@+2 {{output shape [-1, 4] is incompatible with return type of operation 'tensor<3x4xf32>'}}
%0 = stablehlo.constant dense<[-1, 4]> : tensor<2xi64>
%1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32>
return %1 : tensor<3x4xf32>
@@ -1123,7 +1123,7 @@ func.func @dynamic_broadcast_in_dim_c7_output_dimensions_negative_size(%arg0: te
// -----

func.func @dynamic_broadcast_in_dim_c7_output_dimensions_mismatching_size(%arg0: tensor<4xf32>) -> tensor<3x4xf32> {
-// @expected-error@+2 {{output shape [1, 4] is incompatible with return type of operation 'tensor<3x4xf32>'}}
+// expected-error@+2 {{output shape [1, 4] is incompatible with return type of operation 'tensor<3x4xf32>'}}
%0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64>
%1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32>
return %1 : tensor<3x4xf32>
@@ -1235,7 +1235,7 @@ func.func @if(%pred : tensor<i1>, %branch_operand : tensor<2xf32>) -> tensor<2xf

func.func @if_c1(%pred : tensor<i1>, %branch_operand : tensor<f32>) -> tensor<f32> {
// expected-error@+2 {{failed to infer returned types}}
-// @expected-error@+1 {{branch 0 must have 0 arguments, but found 1}}
+// expected-error@+1 {{branch 0 must have 0 arguments, but found 1}}
%0 = "stablehlo.if"(%pred) ({
^bb0(%arg0: tensor<f32>):
"stablehlo.return"(%branch_operand) : (tensor<f32>) -> ()
@@ -1249,7 +1249,7 @@ func.func @if_c1(%pred : tensor<i1>, %branch_operand : tensor<f32>) -> tensor<f3

func.func @if_c1(%pred : tensor<i1>, %branch_operand : tensor<f32>) -> tensor<f32> {
// expected-error@+2 {{failed to infer returned types}}
-// @expected-error@+1 {{branch 1 must have 0 arguments, but found 1}}
+// expected-error@+1 {{branch 1 must have 0 arguments, but found 1}}
%0 = "stablehlo.if"(%pred) ({
"stablehlo.return"(%branch_operand) : (tensor<f32>) -> ()
}, {
@@ -1263,7 +1263,7 @@ func.func @if_c1(%pred : tensor<i1>, %branch_operand : tensor<f32>) -> tensor<f3

func.func @if_c2(%pred : tensor<i1>, %branch_operand : tensor<f32>) -> tensor<f32> {
// expected-error@+2 {{failed to infer returned types}}
-// @expected-error@+1 {{branch 0 and branch 1 have mismatched return types: 'tensor<f32>', 'tensor<f32>' vs 'tensor<f32>'}}
+// expected-error@+1 {{branch 0 and branch 1 have mismatched return types: 'tensor<f32>', 'tensor<f32>' vs 'tensor<f32>'}}
%0 = "stablehlo.if"(%pred) ({
"stablehlo.return"(%branch_operand, %branch_operand) : (tensor<f32>, tensor<f32>) -> ()
}, {
@@ -1276,7 +1276,7 @@ func.func @if_c2(%pred : tensor<i1>, %branch_operand : tensor<f32>) -> tensor<f3

func.func @if_c3(%pred : tensor<i1>, %branch_operand : tensor<f32>) -> tensor<i32> {
// expected-error@+2 {{failed to infer returned types}}
-// @expected-error@+1 {{inferred type(s) 'tensor<f32>' are incompatible with return type(s) of operation 'tensor<i32>'}}
+// expected-error@+1 {{inferred type(s) 'tensor<f32>' are incompatible with return type(s) of operation 'tensor<i32>'}}
%0 = "stablehlo.if"(%pred) ({
"stablehlo.return"(%branch_operand) : (tensor<f32>) -> ()
}, {
@@ -1313,7 +1313,7 @@ func.func @if_dynamic_op_result(%pred : tensor<i1>, %branch_operand: tensor<2xf3

func.func @if_i1(%pred : tensor<1xi1>, %branch_operand : tensor<f32>) -> tensor<f32> {
// expected-error@+2 {{failed to infer returned types}}
-// @expected-error@+1 {{operand should be rank 0 tensor but got rank 1}}
+// expected-error@+1 {{operand should be rank 0 tensor but got rank 1}}
%0 = "stablehlo.if"(%pred) ({
"stablehlo.return"(%branch_operand) : (tensor<f32>) -> ()
}, {
@@ -1338,7 +1338,7 @@ func.func @case(%index : tensor<i32>, %branch_operand : tensor<f32>) -> (tensor<

func.func @case_c1(%index : tensor<i32>, %branch_operand : tensor<2xf32>) -> tensor<2xf32> {
// expected-error@+2 {{failed to infer returned types}}
-// @expected-error@+1 {{expect at least one branch}}
+// expected-error@+1 {{expect at least one branch}}
%0 = "stablehlo.case"(%index) : (tensor<i32>) -> tensor<2xf32>
func.return %0 : tensor<2xf32>
}
@@ -1347,7 +1347,7 @@ func.func @case_c1(%index : tensor<i32>, %branch_operand : tensor<2xf32>) -> ten

func.func @case_c2(%index : tensor<i32>, %branch_operand : tensor<f32>) -> tensor<f32> {
// expected-error@+2 {{failed to infer returned types}}
-// @expected-error@+1 {{branch 1 must have 0 arguments, but found 1}}
+// expected-error@+1 {{branch 1 must have 0 arguments, but found 1}}
%0 = "stablehlo.case"(%index) ({
"stablehlo.return"(%branch_operand) : (tensor<f32>) -> ()
}, {
@@ -3258,7 +3258,7 @@ func.func @dynamic_pad_c2(
%arg: tensor<4xf64>, %padding_value: tensor<f64>,
%padding_low: tensor<2xi32>, %padding_high: tensor<2xi32>, %interior_padding: tensor<2xi32>
) {
-// @expected-error@+1 {{padding operands size (2) must match operand rank (1)}}
+// expected-error@+1 {{padding operands size (2) must match operand rank (1)}}
%0 = stablehlo.dynamic_pad %arg, %padding_value, %padding_low, %padding_high, %interior_padding
: (tensor<4xf64>, tensor<f64>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?xf64>
func.return
@@ -3271,7 +3271,7 @@ func.func @dynamic_pad_c3(
%padding_low: tensor<1xi32>, %padding_high: tensor<1xi32>
) {
%interior_padding = stablehlo.constant dense<-1> : tensor<1xi32>
-// @expected-error@+1 {{interior_padding must be non-negative, but got -1}}
+// expected-error@+1 {{interior_padding must be non-negative, but got -1}}
%0 = stablehlo.dynamic_pad %arg, %padding_value, %padding_low, %padding_high, %interior_padding
: (tensor<4xf64>, tensor<f64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xf64>
func.return
@@ -3281,7 +3281,7 @@ func.func @dynamic_pad_c3(

func.func @dynamic_pad_c4(%arg: tensor<4xf64>, %padding_value: tensor<f64>) {
%padding = stablehlo.constant dense<1> : tensor<1xi32>
-// @expected-error@+1 {{expected output dimension at index 0 to equal 9, but got 4}}
+// expected-error@+1 {{expected output dimension at index 0 to equal 9, but got 4}}
%0 = stablehlo.dynamic_pad %arg, %padding_value, %padding, %padding, %padding
: (tensor<4xf64>, tensor<f64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf64>
func.return
@@ -6098,7 +6098,7 @@ func.func @dynamic_iota_invalid_iota_dimension_too_big() -> tensor<?xf32> {
// -----

func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> {
-// @expected-error@+2 {{output shape [-1] is incompatible with return type of operation 'tensor<4xf32>'}}
+// expected-error@+2 {{output shape [-1] is incompatible with return type of operation 'tensor<4xf32>'}}
%0 = stablehlo.constant dense<[-1]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
func.return %1 : tensor<4xf32>
@@ -6107,7 +6107,7 @@ func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> {
// -----

func.func @dynamic_iota_output_shape_mismatching_size() -> tensor<4xf32> {
-// @expected-error@+2 {{output shape [1] is incompatible with return type of operation 'tensor<4xf32>'}}
+// expected-error@+2 {{output shape [1] is incompatible with return type of operation 'tensor<4xf32>'}}
%0 = stablehlo.constant dense<[1]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
func.return %1 : tensor<4xf32>
@@ -6151,7 +6151,7 @@ func.func @composite_generic(%arg0: tensor<f32>, %arg1: tensor<f32>) {

func.func @foo() { func.return }
func.func @composite_c1() {
-// @expected-error@+1 {{name must be a valid namespaced op name}}
+// expected-error@+1 {{name must be a valid namespaced op name}}
stablehlo.composite "foo" { decomposition = @foo } : () -> ()
func.return
}
@@ -6160,7 +6160,7 @@ func.func @composite_c1() {

func.func @foo() { func.return }
func.func @composite_c1() {
-// @expected-error@+1 {{name must be a valid namespaced op name}}
+// expected-error@+1 {{name must be a valid namespaced op name}}
stablehlo.composite "." { decomposition = @foo } : () -> ()
func.return
}
@@ -6169,7 +6169,7 @@ func.func @composite_c1() {

func.func @foo() { func.return }
func.func @composite_c1() {
-// @expected-error@+1 {{name must be a valid namespaced op name}}
+// expected-error@+1 {{name must be a valid namespaced op name}}
stablehlo.composite "foo." { decomposition = @foo } : () -> ()
func.return
}
@@ -6178,7 +6178,7 @@ func.func @composite_c1() {

func.func @foo() { func.return }
func.func @composite_c1() {
-// @expected-error@+1 {{name must be a valid namespaced op name}}
+// expected-error@+1 {{name must be a valid namespaced op name}}
stablehlo.composite ".foo" { decomposition = @foo } : () -> ()
func.return
}
@@ -6187,7 +6187,7 @@ func.func @composite_c1() {

func.func @foo() { func.return }
func.func @composite_c1() {
-// @expected-error@+1 {{name must be a valid namespaced op name}}
+// expected-error@+1 {{name must be a valid namespaced op name}}
stablehlo.composite "0.foo" { decomposition = @foo } : () -> ()
func.return
}
@@ -6196,7 +6196,7 @@ func.func @composite_c1() {

func.func @foo() { func.return }
func.func @composite_c1() {
-// @expected-error@+1 {{name must be a valid namespaced op name}}
+// expected-error@+1 {{name must be a valid namespaced op name}}
stablehlo.composite "foo.%" { decomposition = @foo } : () -> ()
func.return
}
@@ -6205,7 +6205,7 @@ func.func @composite_c1() {

func.func @foo() { func.return }
func.func @composite_c1() {
-// @expected-error@+1 {{name must be a valid namespaced op name}}
+// expected-error@+1 {{name must be a valid namespaced op name}}
stablehlo.composite "foo.foo.%" { decomposition = @foo } : () -> ()
func.return
}
@@ -6222,7 +6222,7 @@ func.func @composite_c1() {
// -----

func.func @composite_c2(%arg0: tensor<f32>) {
-// @expected-error@+1 {{'nonexistent' does not reference a valid function}}
+// expected-error@+1 {{'nonexistent' does not reference a valid function}}
%0 = stablehlo.composite "stablehlo.nonexistent" %arg0 {
decomposition = @nonexistent
} : (tensor<f32>) -> tensor<f32>
@@ -6237,7 +6237,7 @@ func.func @foo() -> !stablehlo.token {
}

func.func @composite_c3(%arg0: tensor<f32>) {
-// @expected-error@+1 {{has 1 operand(s), but decomposition has 0}}
+// expected-error@+1 {{has 1 operand(s), but decomposition has 0}}
%0 = stablehlo.composite "stablehlo.identity" %arg0 {
decomposition = @foo
} : (tensor<f32>) -> !stablehlo.token
@@ -6252,7 +6252,7 @@ func.func @foo(%arg0: tensor<f64>) -> !stablehlo.token {
}

func.func @composite_c3(%arg0: tensor<f32>) {
-// @expected-error@+1 {{operand at index 0 has type 'tensor<f32>', but decomposition has type 'tensor<f64>'}}
+// expected-error@+1 {{operand at index 0 has type 'tensor<f32>', but decomposition has type 'tensor<f64>'}}
%0 = stablehlo.composite "stablehlo.identity" %arg0 {
decomposition = @foo
} : (tensor<f32>) -> !stablehlo.token
@@ -6266,7 +6266,7 @@ func.func @foo(%arg0: !stablehlo.token) {
}

func.func @composite_c4(%arg0: !stablehlo.token) {
-// @expected-error@+1 {{has 1 result(s), but decomposition has 0}}
+// expected-error@+1 {{has 1 result(s), but decomposition has 0}}
%0 = stablehlo.composite "stablehlo.identity" %arg0 {
decomposition = @foo
} : (!stablehlo.token) -> tensor<f32>
@@ -6281,7 +6281,7 @@ func.func @foo(%arg0: !stablehlo.token) -> tensor<f64> {
}

func.func @composite_c4(%arg0: !stablehlo.token) {
-// @expected-error@+1 {{result at index 0 has type 'tensor<f32>', but decomposition has type 'tensor<f64>'}}
+// expected-error@+1 {{result at index 0 has type 'tensor<f32>', but decomposition has type 'tensor<f64>'}}
%0 = stablehlo.composite "stablehlo.identity" %arg0 {
decomposition = @foo
} : (!stablehlo.token) -> tensor<f32>
