Skip to content

Commit

Permalink
Merge pull request #8 from kshyatt-aws/ksh/precedence
Browse files Browse the repository at this point in the history
fix operator precedence issues
  • Loading branch information
kshyatt-aws authored Nov 6, 2024
2 parents 91e6a42 + e841be2 commit f045887
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 35 deletions.
89 changes: 54 additions & 35 deletions src/Quasar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const first_letter = re"[A-Za-z_]" | unicode
const general_letter = first_letter | re"[0-9]"

const prefloat = re"[-+]?([0-9]+\.[0-9]*|[0-9]*\.[0-9]+)"
const integer = re"[-+]?[0-9]+"
const integer = re"[-]?[0-9]+"
const float = prefloat | ((prefloat | re"[-+]?[0-9]+") * re"[eE][-+]?[0-9]+")

const qasm_tokens = [
Expand Down Expand Up @@ -314,23 +314,17 @@ end
Base.length(s::SizedBitVector) = s.size
Base.size(s::SizedBitVector) = (s.size,)
Base.show(io::IO, s::SizedBitVector) = print(io, "SizedBitVector{$(s.size.args[end])}")
Base.iterate(s::SizedBitVector) = nothing
Base.iterate(s::SizedBitVector, ::Nothing) = nothing
struct SizedInt <: Integer
size::QasmExpression
SizedInt(size::QasmExpression) = new(size)
SizedInt(sint::SizedInt) = new(sint.size)
end
Base.iterate(s::SizedInt) = nothing
Base.iterate(s::SizedInt, ::Nothing) = nothing
Base.show(io::IO, s::SizedInt) = print(io, "SizedInt{$(s.size.args[end])}")
struct SizedUInt <: Unsigned
size::QasmExpression
SizedUInt(size::QasmExpression) = new(size)
SizedUInt(suint::SizedUInt) = new(suint.size)
end
Base.iterate(s::SizedUInt) = nothing
Base.iterate(s::SizedUInt, ::Nothing) = nothing
Base.show(io::IO, s::SizedUInt) = print(io, "SizedUInt{$(s.size.args[end])}")
struct SizedFloat <: AbstractFloat
size::QasmExpression
Expand All @@ -343,16 +337,12 @@ struct SizedAngle <: AbstractFloat
SizedAngle(size::QasmExpression) = new(size)
SizedAngle(sangle::SizedAngle) = new(sangle.size)
end
Base.iterate(s::SizedAngle) = nothing
Base.iterate(s::SizedAngle, ::Nothing) = nothing
Base.show(io::IO, s::SizedAngle) = print(io, "SizedAngle{$(s.size.args[end])}")
struct SizedComplex <: Number
size::QasmExpression
SizedComplex(size::QasmExpression) = new(size)
SizedComplex(scomplex::SizedComplex) = new(scomplex.size)
end
Base.iterate(s::SizedComplex) = nothing
Base.iterate(s::SizedComplex, ::Nothing) = nothing
Base.show(io::IO, s::SizedComplex) = print(io, "SizedComplex{$(s.size.args[end])}")

struct SizedArray{T,N} <: AbstractArray{T, N}
Expand All @@ -367,14 +357,14 @@ function SizedArray(eltype::QasmExpression, size::QasmExpression)
end
return SizedArray(eltype.args[1], arr_size)
end
Base.iterate(s::SizedArray) = nothing
Base.iterate(s::SizedArray, ::Nothing) = nothing
Base.show(io::IO, s::SizedArray{T, N}) where {T, N} = print(io, "SizedArray{$(sprint(show, s.type)), $N}")
Base.size(a::SizedArray{T, N}, dim::Int=0) where {T, N} = a.size[dim+1]

const SizedNumber = Union{SizedComplex, SizedAngle, SizedFloat, SizedInt, SizedUInt}
Base.iterate(::Union{SizedArray, SizedBitVector, SizedNumber}) = nothing
Base.iterate(::Union{SizedArray, SizedBitVector, SizedNumber}, ::Nothing) = nothing
if v"1.9" <= VERSION < v"1.11"
Base.Iterators.iterlength(s::Union{SizedNumber, SizedBitVector, SizedArray}) = -1
Base.Iterators.iterlength(::Union{SizedNumber, SizedBitVector, SizedArray}) = -1
end

function parse_classical_type(tokens, stack, start, qasm)
Expand Down Expand Up @@ -498,9 +488,8 @@ end

function extract_braced_block(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, start, qasm)
bracket_opening = findfirst(triplet->triplet[end] == lbracket, tokens)
bracket_closing = findfirst(triplet->triplet[end] == rbracket, tokens)
isnothing(bracket_opening) && throw(QasmParseError("missing opening [ ", stack, start, qasm))
opener = popat!(tokens, bracket_opening)
popat!(tokens, bracket_opening)
openers_met = 1
closers_met = 0
braced_tokens = Tuple{Int64, Int32, Token}[]
Expand Down Expand Up @@ -534,7 +523,7 @@ end
function parse_bracketed_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, start, qasm)
interior_tokens = extract_braced_block(tokens, stack, start, qasm)
push!(interior_tokens, (-1, Int32(-1), semicolon))
return parse_expression(interior_tokens, stack, start, qasm)
return parse_list_expression(interior_tokens, stack, start, qasm)
end

function parse_paren_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, start, qasm)
Expand Down Expand Up @@ -653,6 +642,22 @@ function parse_gate_mods(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, star
end
end

function _op_precedence(op::Symbol)::Int
op == Symbol("**") && return 0
op (:!, :~) && return 1
op (:*, :/, :%) && return 2
op (:+, :-) && return 3
op (Symbol("<<"), Symbol(">>")) && return 4
op (:<, Symbol("<="), :>, Symbol(">=")) && return 5
op (Symbol("!="), Symbol("==")) && return 6
op == :& && return 7
op == :^ && return 8
op == :| && return 9
op == Symbol("&&") && return 10
op == Symbol("||") && return 11
end
has_precedence(op_a::Symbol, op_b::Symbol) = _op_precedence(op_a) <= _op_precedence(op_b)

function parse_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, start, qasm)
start_token = popfirst!(tokens)
next_token = first(tokens)
Expand Down Expand Up @@ -708,10 +713,10 @@ function parse_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, sta
expr = QasmExpression(:unary_op, unary_op_symbol, next_expr)
elseif head(next_expr) == :binary_op
# replace first argument
left_hand_side = next_expr.args[2]::QasmExpression
left_hand_side = next_expr.args[2]::QasmExpression
new_left_hand_side = QasmExpression(:unary_op, unary_op_symbol, left_hand_side)
next_expr.args[2] = new_left_hand_side
expr = next_expr
next_expr.args[2] = new_left_hand_side
expr = next_expr
end
elseif next_token[end] == colon
start = token_name
Expand Down Expand Up @@ -755,7 +760,7 @@ function parse_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, sta
right_hand_side = QasmExpression(:measure, parse_expression(tokens, stack, start, qasm))
elseif next_token[end] == operator
unary_op_token = parse_identifier(popfirst!(tokens), qasm)
next_token = first(tokens)
next_token = first(tokens)
unary_right_hand_side = next_token[end] == lparen ? parse_paren_expression(tokens, stack, start, qasm) : parse_expression(tokens, stack, start, qasm)
right_hand_side = QasmExpression(:unary_op, Symbol(unary_op_token.args[1]::String), unary_right_hand_side)
else
Expand All @@ -764,29 +769,43 @@ function parse_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, sta
op_expr = QasmExpression(:binary_op, parse_assignment_op(op_token, qasm), token_name, right_hand_side)
expr = QasmExpression(:classical_assignment, op_expr)
elseif next_token[end] == operator
op_token = parse_identifier(popfirst!(tokens), qasm)
op_token = parse_identifier(popfirst!(tokens), qasm)
left_hand_side = token_name
right_hand_side = parse_expression(tokens, stack, start, qasm)::QasmExpression
expr = QasmExpression(:binary_op, Symbol(op_token.args[1]), token_name, right_hand_side)
op_symbol = Symbol(op_token.args[1])
raw_expr = QasmExpression(:binary_op, op_symbol, token_name, right_hand_side)
if head(right_hand_side) == :binary_op && has_precedence(op_symbol, right_hand_side.args[1])
rhs_op = right_hand_side.args[1]
new_lhs = QasmExpression(:binary_op, op_symbol, left_hand_side, right_hand_side.args[2])
new_rhs = right_hand_side.args[end]
raw_expr = QasmExpression(:binary_op, rhs_op, new_lhs, new_rhs)
end
if !isempty(tokens) && first(tokens)[end] == im_token
expr = QasmExpression(:binary_op, :*, raw_expr, QasmExpression(:complex_literal, im))
popfirst!(tokens)
else
expr = raw_expr
end
else # some kind of function or gate call
# either a gate call or function call
arguments = parse_arguments_list(tokens, stack, start, qasm)
next_token = first(tokens)
is_gphase::Bool = (token_name isa QasmExpression && head(token_name) == :identifier && token_name.args[1]::String == "gphase")::Bool
# this is a gate call with qubit targets
is_gate_call = next_token[end] == identifier || next_token[end] == hw_qubit || is_gphase
# this is a function call - unless it is gphase!
if (next_token[end] == semicolon && !is_gphase)
popfirst!(tokens)
expr = QasmExpression(:function_call, token_name, arguments)
elseif next_token[end] == operator # actually a binary op!
op_token = parse_identifier(popfirst!(tokens), qasm)
left_hand_side = QasmExpression(:function_call, token_name, arguments)
right_hand_side = parse_expression(tokens, stack, start, qasm)
expr = QasmExpression(:binary_op, Symbol(op_token.args[1]), left_hand_side, right_hand_side)
else # it's a gate call or gphase
if is_gate_call
target_expr = QasmExpression(:qubit_targets, parse_list_expression(tokens, stack, start, qasm))
expr = QasmExpression(:gate_call, token_name, arguments)
push!(expr, target_expr)
expr = QasmExpression(:gate_call, token_name, arguments, target_expr)
else # this is a function call
next_token[end] == semicolon && popfirst!(tokens)
if next_token[end] == operator
next_op_token = parse_identifier(popfirst!(tokens), qasm)
left_hand_side = QasmExpression(:function_call, token_name, arguments)
right_hand_side = parse_expression(tokens, stack, start, qasm)::QasmExpression
expr = QasmExpression(:binary_op, Symbol(next_op_token.args[1]), left_hand_side, right_hand_side)
else
expr = QasmExpression(:function_call, token_name, arguments)
end
end
end
return expr
Expand Down
17 changes: 17 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,23 @@ Quasar.builtin_gates[] = complex_builtin_gates
(type="gphase", arguments=InstructionArgument[2*π], targets=[0, 1], controls=[0=>0, 1=>1], exponent=1.0),
]
end
@testset "Operator precedence $expr -> $val" for (expr, val) in (("complex[float] a = 1/sqrt(2)+1/sqrt(2)im;", (1+im)/√2),
("float a = 2*3 - 4*5;", 6-20),
("float a = 2*(3 - 4)*5;", -10),
("float a = 1/2+4;", 4.5),
("int a = 2 + 3*4 - 5;", 14 - 5),
("complex[float] a = 2+1/3im;", 2-(im/3)),
("bool a = 1 << 2 == 5;", false),
("bool a = true && true || false;", true),
)
qasm = """
$expr
"""
parsed = parse_qasm(qasm)
visitor = QasmProgramVisitor()
visitor(parsed)
@test visitor.classical_defs["a"].val == val
end
@testset "Casting" begin
@testset "Casting to $to_type from $from_type" for (to_type, to_value) in (("bool", true),), (from_type, from_value) in (("int[32]", "32",),
("uint[16]", "1",),
Expand Down

0 comments on commit f045887

Please sign in to comment.