diff --git a/src/types.jl b/src/types.jl index 73e86fe..9cb8764 100644 --- a/src/types.jl +++ b/src/types.jl @@ -27,35 +27,41 @@ struct SizedInt <: Integer SizedInt(sint::SizedInt) = new(sint.size) end Base.show(io::IO, s::SizedInt) = print(io, "SizedInt{$(s.size.args[end])}") +Base.:(==)(i1::SizedInt, i2::SizedInt) = i1.size == i2.size struct SizedUInt <: Unsigned size::QasmExpression SizedUInt(size::QasmExpression) = new(size) SizedUInt(suint::SizedUInt) = new(suint.size) end +Base.:(==)(ui1::SizedUInt, ui2::SizedUInt) = ui1.size == ui2.size Base.show(io::IO, s::SizedUInt) = print(io, "SizedUInt{$(s.size.args[end])}") struct SizedFloat <: AbstractFloat size::QasmExpression SizedFloat(size::QasmExpression) = new(size) SizedFloat(sfloat::SizedFloat) = new(sfloat.size) end +Base.:(==)(f1::SizedFloat, f2::SizedFloat) = f1.size == f2.size Base.show(io::IO, s::SizedFloat) = print(io, "SizedFloat{$(s.size.args[end])}") struct SizedAngle <: AbstractFloat size::QasmExpression SizedAngle(size::QasmExpression) = new(size) SizedAngle(sangle::SizedAngle) = new(sangle.size) end +Base.:(==)(a1::SizedAngle, a2::SizedAngle) = a1.size == a2.size Base.show(io::IO, s::SizedAngle) = print(io, "SizedAngle{$(s.size.args[end])}") struct SizedComplex <: Number size::QasmExpression SizedComplex(size::QasmExpression) = new(size) SizedComplex(scomplex::SizedComplex) = new(scomplex.size) end +Base.:(==)(c1::SizedComplex, c2::SizedComplex) = c1.size == c2.size Base.show(io::IO, s::SizedComplex) = print(io, "SizedComplex{$(s.size.args[end])}") struct SizedArray{T,N} <: AbstractArray{T, N} type::T size::NTuple{N, Int} end +Base.:(==)(a1::SizedArray{T, N}, a2::SizedArray{T, N}) where {T,N} = (a1.type == a2.type && a1.size == a2.size) function SizedArray(eltype::QasmExpression, size::QasmExpression) arr_size = if head(size) == :n_dims ntuple(i->0, size.args[1].args[1]) diff --git a/src/visitor.jl b/src/visitor.jl index 238d95a..701df17 100644 --- a/src/visitor.jl +++ b/src/visitor.jl @@ -634,7 +634,15 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) qubit_mapping(v)["$alias_name[$qubit_i]"] = [alias_qubits[qubit_i+1]] end else # both classical - throw(QasmVisitorError("classical array concatenation not yet supported!")) + left_array = classical_defs(v)[name(concat_left)] + right_array = classical_defs(v)[name(concat_right)] + new_size = QasmExpression(:binary_op, :+, only(size(left_array.type)), only(size(right_array.type))) + if left_array.type isa SizedBitVector + classical_defs(v)[alias_name] = ClassicalVariable(alias_name, new_size, vcat(left_array.val, right_array.val), false) + else + left_array.type == right_array.type || throw(QasmVisitorError("only arrays of the same element type can be concatenated")) + classical_defs(v)[alias_name] = ClassicalVariable(alias_name, left_array.type, vcat(left_array.val, right_array.val), false) + end end elseif head(right_hand_side) == :identifier referent_name = name(right_hand_side) @@ -662,7 +670,12 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) end else referent = classical_defs(v)[referent_name] - classical_defs(v)[alias_name] = ClassicalVariable(alias_name, referent.type, view(referent.val, v(right_hand_side.args[end]) .+ 1), referent.is_const) + ref_ixs = v(right_hand_side.args[end]) + ixs = map(ref_ixs) do ix + ix >= 0 && return ix + 1 + ix < 0 && return v(length(referent.type)) + 1 + ix + end + classical_defs(v)[alias_name] = ClassicalVariable(alias_name, referent.type, view(referent.val, ixs), referent.is_const) end end elseif head(program_expr) == :identifier diff --git a/test/runtests.jl b/test/runtests.jl index 93b5d6e..a58bb2e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -250,26 +250,66 @@ Quasar.builtin_gates[] = complex_builtin_gates """ parsed = parse_qasm(qasm) visitor = QasmProgramVisitor() - @test_throws Quasar.QasmVisitorError("classical array concatenation not yet supported!") visitor(parsed) - #@test collect(visitor.classical_defs["concatenated"].val) == BitVector((true, false, false, true)) - #@test visitor.classical_defs["first"].val == true - #@test visitor.classical_defs["last"].val == true - #@test collect(visitor.classical_defs["new_cat"].val) == BitVector((true, false, false, true)) + visitor(parsed) + @test collect(visitor.classical_defs["concatenated"].val) == BitVector((true, false, false, true)) + @test only(visitor.classical_defs["first"].val) == true + @test only(visitor.classical_defs["last"].val) == true + @test collect(visitor.classical_defs["new_cat"].val) == BitVector((true, false, false, true)) + qasm = """ + bit[2] one = "01"; + bit[2] two = "10"; + // Aliased register of four bits + let concatenated = two ++ one; // "1001" + concatenated[0] = false; + """ + parsed = parse_qasm(qasm) + visitor = QasmProgramVisitor() + visitor(parsed) + @test collect(visitor.classical_defs["concatenated"].val) == BitVector((false, false, false, true)) + @test collect(visitor.classical_defs["two"].val) == BitVector((true, false)) # test that these are *references* qasm = """ bit[2] one = "01"; bit[2] two = "10"; + array[int[8], 2] int_one = [1, 2]; + array[int[8], 2] int_two = [2, 3]; + array[uint[8], 2] uint_one = [1, 2]; + array[uint[8], 2] uint_two = [2, 3]; + array[float[16], 2] float_one = [1.0, 2.0]; + array[float[16], 2] float_two = [2.0, 3.0]; + array[angle[16], 2] angle_one = [1.0, 2.0]; + array[angle[16], 2] angle_two = [2.0, 3.0]; + array[complex[float[16]], 2] complex_one = [1.0im, 2.0]; + array[complex[float[16]], 2] complex_two = [2.0, 3.0im]; // Aliased register of four bits let concatenated = one; // "01" // First bit in aliased qubit array let first = concatenated[0]; concatenated[1] = false; + let int_concatenated = int_one ++ int_two; + let uint_concatenated = uint_one ++ uint_two; + let float_concatenated = float_one ++ float_two; + let angle_concatenated = angle_one ++ angle_two; + let complex_concatenated = complex_one ++ complex_two; """ parsed = parse_qasm(qasm) visitor = QasmProgramVisitor() visitor(parsed) @test visitor.classical_defs["one"].val == BitVector((false, false)) @test only(visitor.classical_defs["first"].val) == false + @test visitor.classical_defs["int_concatenated"].val == [1, 2, 2, 3] + @test visitor.classical_defs["uint_concatenated"].val == [1, 2, 2, 3] + @test visitor.classical_defs["float_concatenated"].val == [1.0, 2.0, 2.0, 3.0] + @test visitor.classical_defs["angle_concatenated"].val == [1.0, 2.0, 2.0, 3.0] + @test visitor.classical_defs["complex_concatenated"].val == [1im, 2.0, 2.0, 3im] + qasm = """ + array[int[8], 2] one = [1, 1]; + array[int[32], 2] two = [0, 0]; + let concatenated = one ++ two; + """ + parsed = parse_qasm(qasm) + visitor = QasmProgramVisitor() + @test_throws Quasar.QasmVisitorError("only arrays of the same element type can be concatenated") visitor(parsed) end @testset "Randomized Benchmarking" begin qasm = """