Skip to content

Commit

Permalink
[red-knot] More precise inference for chained boolean expressions (#1…
Browse files Browse the repository at this point in the history
…5089)

## Summary

Resolves #13632.

## Test Plan

Markdown tests.
  • Loading branch information
InSyncWithFoo authored Dec 22, 2024
1 parent 60e433c commit 3b27d5d
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ class C:
def __lt__(self, other) -> C: ...

x = A() < B() < C()
reveal_type(x) # revealed: A | B
reveal_type(x) # revealed: A & ~AlwaysTruthy | B

y = 0 < 1 < A() < 3
reveal_type(y) # revealed: bool | A
reveal_type(y) # revealed: Literal[False] | A

z = 10 < 0 < A() < B() < C()
reveal_type(z) # revealed: Literal[False]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def _(foo: str):
reveal_type(False or "z") # revealed: Literal["z"]
reveal_type(False or True) # revealed: Literal[True]
reveal_type(False or False) # revealed: Literal[False]
reveal_type(foo or False) # revealed: str | Literal[False]
reveal_type(foo or True) # revealed: str | Literal[True]
reveal_type(foo or False) # revealed: str & ~AlwaysFalsy | Literal[False]
reveal_type(foo or True) # revealed: str & ~AlwaysFalsy | Literal[True]
```

## AND
Expand All @@ -20,8 +20,8 @@ def _(foo: str):
def _(foo: str):
reveal_type(True and False) # revealed: Literal[False]
reveal_type(False and True) # revealed: Literal[False]
reveal_type(foo and False) # revealed: str | Literal[False]
reveal_type(foo and True) # revealed: str | Literal[True]
reveal_type(foo and False) # revealed: str & ~AlwaysTruthy | Literal[False]
reveal_type(foo and True) # revealed: str & ~AlwaysTruthy | Literal[True]
reveal_type("x" and "y" and "z") # revealed: Literal["z"]
reveal_type("x" and "y" and "") # revealed: Literal[""]
reveal_type("" and "y") # revealed: Literal[""]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,51 @@ else:
# TODO: It should be A. We should improve UnionBuilder or IntersectionBuilder. (issue #15023)
reveal_type(y) # revealed: A & ~AlwaysTruthy | A & ~AlwaysFalsy
```

## Narrowing in chained boolean expressions

```py
from typing import Literal

class A: ...

def _(x: Literal[0, 1]):
reveal_type(x or A()) # revealed: Literal[1] | A
reveal_type(x and A()) # revealed: Literal[0] | A

def _(x: str):
reveal_type(x or A()) # revealed: str & ~AlwaysFalsy | A
reveal_type(x and A()) # revealed: str & ~AlwaysTruthy | A

def _(x: bool | str):
reveal_type(x or A()) # revealed: Literal[True] | str & ~AlwaysFalsy | A
reveal_type(x and A()) # revealed: Literal[False] | str & ~AlwaysTruthy | A

class Falsy:
def __bool__(self) -> Literal[False]: ...

class Truthy:
def __bool__(self) -> Literal[True]: ...

def _(x: Falsy | Truthy):
reveal_type(x or A()) # revealed: Truthy | A
reveal_type(x and A()) # revealed: Falsy | A

class MetaFalsy(type):
def __bool__(self) -> Literal[False]: ...

class MetaTruthy(type):
def __bool__(self) -> Literal[False]: ...

class FalsyClass(metaclass=MetaFalsy): ...
class TruthyClass(metaclass=MetaTruthy): ...

def _(x: type[FalsyClass] | type[TruthyClass]):
# TODO: Should be `type[TruthyClass] | A`
# revealed: type[FalsyClass] & ~AlwaysFalsy | type[TruthyClass] & ~AlwaysFalsy | A
reveal_type(x or A())

# TODO: Should be `type[FalsyClass] | A`
# revealed: type[FalsyClass] & ~AlwaysTruthy | type[TruthyClass] & ~AlwaysTruthy | A
reveal_type(x and A())
```
50 changes: 30 additions & 20 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3582,27 +3582,37 @@ impl<'db> TypeInferenceBuilder<'db> {
n_values: usize,
) -> Type<'db> {
let mut done = false;
UnionType::from_elements(
db,
values.into_iter().enumerate().map(|(i, ty)| {
if done {
Type::Never
} else {
let is_last = i == n_values - 1;
match (ty.bool(db), is_last, op) {
(Truthiness::Ambiguous, _, _) => ty,
(Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
(Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,
(Truthiness::AlwaysFalse, _, ast::BoolOp::And)
| (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => {
done = true;
ty
}
(_, true, _) => ty,
}

let elements = values.into_iter().enumerate().map(|(i, ty)| {
if done {
return Type::Never;
}

let is_last = i == n_values - 1;

match (ty.bool(db), is_last, op) {
(Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
(Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,

(Truthiness::AlwaysFalse, _, ast::BoolOp::And)
| (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => {
done = true;
ty
}
}),
)

(Truthiness::Ambiguous, false, _) => IntersectionBuilder::new(db)
.add_positive(ty)
.add_negative(match op {
ast::BoolOp::And => Type::AlwaysTruthy,
ast::BoolOp::Or => Type::AlwaysFalsy,
})
.build(),

(_, true, _) => ty,
}
});

UnionType::from_elements(db, elements)
}

fn infer_compare_expression(&mut self, compare: &ast::ExprCompare) -> Type<'db> {
Expand Down

0 comments on commit 3b27d5d

Please sign in to comment.