From e841be2484f7a39714f7598854d61448736b5580 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 5 Nov 2024 16:11:32 -0500 Subject: [PATCH] fix operator precedence issues --- src/Quasar.jl | 89 +++++++++++++++++++++++++++++------------------- test/runtests.jl | 17 +++++++++ 2 files changed, 71 insertions(+), 35 deletions(-) diff --git a/src/Quasar.jl b/src/Quasar.jl index ccbc696..4dd826a 100644 --- a/src/Quasar.jl +++ b/src/Quasar.jl @@ -24,7 +24,7 @@ const first_letter = re"[A-Za-z_]" | unicode const general_letter = first_letter | re"[0-9]" const prefloat = re"[-+]?([0-9]+\.[0-9]*|[0-9]*\.[0-9]+)" -const integer = re"[-+]?[0-9]+" +const integer = re"[-]?[0-9]+" const float = prefloat | ((prefloat | re"[-+]?[0-9]+") * re"[eE][-+]?[0-9]+") const qasm_tokens = [ @@ -314,23 +314,17 @@ end Base.length(s::SizedBitVector) = s.size Base.size(s::SizedBitVector) = (s.size,) Base.show(io::IO, s::SizedBitVector) = print(io, "SizedBitVector{$(s.size.args[end])}") -Base.iterate(s::SizedBitVector) = nothing -Base.iterate(s::SizedBitVector, ::Nothing) = nothing struct SizedInt <: Integer size::QasmExpression SizedInt(size::QasmExpression) = new(size) SizedInt(sint::SizedInt) = new(sint.size) end -Base.iterate(s::SizedInt) = nothing -Base.iterate(s::SizedInt, ::Nothing) = nothing Base.show(io::IO, s::SizedInt) = print(io, "SizedInt{$(s.size.args[end])}") struct SizedUInt <: Unsigned size::QasmExpression SizedUInt(size::QasmExpression) = new(size) SizedUInt(suint::SizedUInt) = new(suint.size) end -Base.iterate(s::SizedUInt) = nothing -Base.iterate(s::SizedUInt, ::Nothing) = nothing Base.show(io::IO, s::SizedUInt) = print(io, "SizedUInt{$(s.size.args[end])}") struct SizedFloat <: AbstractFloat size::QasmExpression @@ -343,16 +337,12 @@ struct SizedAngle <: AbstractFloat SizedAngle(size::QasmExpression) = new(size) SizedAngle(sangle::SizedAngle) = new(sangle.size) end -Base.iterate(s::SizedAngle) = nothing -Base.iterate(s::SizedAngle, ::Nothing) = nothing Base.show(io::IO, s::SizedAngle) = print(io, "SizedAngle{$(s.size.args[end])}") struct SizedComplex <: Number size::QasmExpression SizedComplex(size::QasmExpression) = new(size) SizedComplex(scomplex::SizedComplex) = new(scomplex.size) end -Base.iterate(s::SizedComplex) = nothing -Base.iterate(s::SizedComplex, ::Nothing) = nothing Base.show(io::IO, s::SizedComplex) = print(io, "SizedComplex{$(s.size.args[end])}") struct SizedArray{T,N} <: AbstractArray{T, N} @@ -367,14 +357,14 @@ function SizedArray(eltype::QasmExpression, size::QasmExpression) end return SizedArray(eltype.args[1], arr_size) end -Base.iterate(s::SizedArray) = nothing -Base.iterate(s::SizedArray, ::Nothing) = nothing Base.show(io::IO, s::SizedArray{T, N}) where {T, N} = print(io, "SizedArray{$(sprint(show, s.type)), $N}") Base.size(a::SizedArray{T, N}, dim::Int=0) where {T, N} = a.size[dim+1] const SizedNumber = Union{SizedComplex, SizedAngle, SizedFloat, SizedInt, SizedUInt} +Base.iterate(::Union{SizedArray, SizedBitVector, SizedNumber}) = nothing +Base.iterate(::Union{SizedArray, SizedBitVector, SizedNumber}, ::Nothing) = nothing if v"1.9" <= VERSION < v"1.11" - Base.Iterators.iterlength(s::Union{SizedNumber, SizedBitVector, SizedArray}) = -1 + Base.Iterators.iterlength(::Union{SizedNumber, SizedBitVector, SizedArray}) = -1 end function parse_classical_type(tokens, stack, start, qasm) @@ -498,9 +488,8 @@ end function extract_braced_block(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, start, qasm) bracket_opening = findfirst(triplet->triplet[end] == lbracket, tokens) - bracket_closing = findfirst(triplet->triplet[end] == rbracket, tokens) isnothing(bracket_opening) && throw(QasmParseError("missing opening [ ", stack, start, qasm)) - opener = popat!(tokens, bracket_opening) + popat!(tokens, bracket_opening) openers_met = 1 closers_met = 0 braced_tokens = Tuple{Int64, Int32, Token}[] @@ -534,7 +523,7 @@ end function parse_bracketed_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, start, qasm) interior_tokens = extract_braced_block(tokens, stack, start, qasm) push!(interior_tokens, (-1, Int32(-1), semicolon)) - return parse_expression(interior_tokens, stack, start, qasm) + return parse_list_expression(interior_tokens, stack, start, qasm) end function parse_paren_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, start, qasm) @@ -653,6 +642,22 @@ function parse_gate_mods(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, star end end +function _op_precedence(op::Symbol)::Int + op == Symbol("**") && return 0 + op ∈ (:!, :~) && return 1 + op ∈ (:*, :/, :%) && return 2 + op ∈ (:+, :-) && return 3 + op ∈ (Symbol("<<"), Symbol(">>")) && return 4 + op ∈ (:<, Symbol("<="), :>, Symbol(">=")) && return 5 + op ∈ (Symbol("!="), Symbol("==")) && return 6 + op == :& && return 7 + op == :^ && return 8 + op == :| && return 9 + op == Symbol("&&") && return 10 + op == Symbol("||") && return 11 +end +has_precedence(op_a::Symbol, op_b::Symbol) = _op_precedence(op_a) <= _op_precedence(op_b) + function parse_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, start, qasm) start_token = popfirst!(tokens) next_token = first(tokens) @@ -708,10 +713,10 @@ function parse_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, sta expr = QasmExpression(:unary_op, unary_op_symbol, next_expr) elseif head(next_expr) == :binary_op # replace first argument - left_hand_side = next_expr.args[2]::QasmExpression + left_hand_side = next_expr.args[2]::QasmExpression new_left_hand_side = QasmExpression(:unary_op, unary_op_symbol, left_hand_side) - next_expr.args[2] = new_left_hand_side - expr = next_expr + next_expr.args[2] = new_left_hand_side + expr = next_expr end elseif next_token[end] == colon start = token_name @@ -755,7 +760,7 @@ function parse_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, sta right_hand_side = QasmExpression(:measure, parse_expression(tokens, stack, start, qasm)) elseif next_token[end] == operator unary_op_token = parse_identifier(popfirst!(tokens), qasm) - next_token = first(tokens) + next_token = first(tokens) unary_right_hand_side = next_token[end] == lparen ? parse_paren_expression(tokens, stack, start, qasm) : parse_expression(tokens, stack, start, qasm) right_hand_side = QasmExpression(:unary_op, Symbol(unary_op_token.args[1]::String), unary_right_hand_side) else @@ -764,9 +769,23 @@ function parse_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, sta op_expr = QasmExpression(:binary_op, parse_assignment_op(op_token, qasm), token_name, right_hand_side) expr = QasmExpression(:classical_assignment, op_expr) elseif next_token[end] == operator - op_token = parse_identifier(popfirst!(tokens), qasm) + op_token = parse_identifier(popfirst!(tokens), qasm) + left_hand_side = token_name right_hand_side = parse_expression(tokens, stack, start, qasm)::QasmExpression - expr = QasmExpression(:binary_op, Symbol(op_token.args[1]), token_name, right_hand_side) + op_symbol = Symbol(op_token.args[1]) + raw_expr = QasmExpression(:binary_op, op_symbol, token_name, right_hand_side) + if head(right_hand_side) == :binary_op && has_precedence(op_symbol, right_hand_side.args[1]) + rhs_op = right_hand_side.args[1] + new_lhs = QasmExpression(:binary_op, op_symbol, left_hand_side, right_hand_side.args[2]) + new_rhs = right_hand_side.args[end] + raw_expr = QasmExpression(:binary_op, rhs_op, new_lhs, new_rhs) + end + if !isempty(tokens) && first(tokens)[end] == im_token + expr = QasmExpression(:binary_op, :*, raw_expr, QasmExpression(:complex_literal, im)) + popfirst!(tokens) + else + expr = raw_expr + end else # some kind of function or gate call # either a gate call or function call arguments = parse_arguments_list(tokens, stack, start, qasm) @@ -774,19 +793,19 @@ function parse_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, sta is_gphase::Bool = (token_name isa QasmExpression && head(token_name) == :identifier && token_name.args[1]::String == "gphase")::Bool # this is a gate call with qubit targets is_gate_call = next_token[end] == identifier || next_token[end] == hw_qubit || is_gphase - # this is a function call - unless it is gphase! - if (next_token[end] == semicolon && !is_gphase) - popfirst!(tokens) - expr = QasmExpression(:function_call, token_name, arguments) - elseif next_token[end] == operator # actually a binary op! - op_token = parse_identifier(popfirst!(tokens), qasm) - left_hand_side = QasmExpression(:function_call, token_name, arguments) - right_hand_side = parse_expression(tokens, stack, start, qasm) - expr = QasmExpression(:binary_op, Symbol(op_token.args[1]), left_hand_side, right_hand_side) - else # it's a gate call or gphase + if is_gate_call target_expr = QasmExpression(:qubit_targets, parse_list_expression(tokens, stack, start, qasm)) - expr = QasmExpression(:gate_call, token_name, arguments) - push!(expr, target_expr) + expr = QasmExpression(:gate_call, token_name, arguments, target_expr) + else # this is a function call + next_token[end] == semicolon && popfirst!(tokens) + if next_token[end] == operator + next_op_token = parse_identifier(popfirst!(tokens), qasm) + left_hand_side = QasmExpression(:function_call, token_name, arguments) + right_hand_side = parse_expression(tokens, stack, start, qasm)::QasmExpression + expr = QasmExpression(:binary_op, Symbol(next_op_token.args[1]), left_hand_side, right_hand_side) + else + expr = QasmExpression(:function_call, token_name, arguments) + end end end return expr diff --git a/test/runtests.jl b/test/runtests.jl index 84b5f53..e5bf27f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -289,6 +289,23 @@ Quasar.builtin_gates[] = complex_builtin_gates (type="gphase", arguments=InstructionArgument[2*π], targets=[0, 1], controls=[0=>0, 1=>1], exponent=1.0), ] end + @testset "Operator precedence $expr -> $val" for (expr, val) in (("complex[float] a = 1/sqrt(2)+1/sqrt(2)im;", (1+im)/√2), + ("float a = 2*3 - 4*5;", 6-20), + ("float a = 2*(3 - 4)*5;", -10), + ("float a = 1/2+4;", 4.5), + ("int a = 2 + 3*4 - 5;", 14 - 5), + ("complex[float] a = 2+1/3im;", 2-(im/3)), + ("bool a = 1 << 2 == 5;", false), + ("bool a = true && true || false;", true), + ) + qasm = """ + $expr + """ + parsed = parse_qasm(qasm) + visitor = QasmProgramVisitor() + visitor(parsed) + @test visitor.classical_defs["a"].val == val + end @testset "Casting" begin @testset "Casting to $to_type from $from_type" for (to_type, to_value) in (("bool", true),), (from_type, from_value) in (("int[32]", "32",), ("uint[16]", "1",),