vperm: support lmul > 1 for vslideup/dn and vrgather
Ziyue-Zhang committed Jan 12, 2024
1 parent 27c756f commit 8959d36
Showing 1 changed file with 172 additions and 57 deletions.
229 changes: 172 additions & 57 deletions src/main/scala/yunsuan/vector/VectorPerm/Permutation.scala
@@ -5,13 +5,85 @@ import chisel3.util._
import chisel3.util.experimental.decode.TruthTable
import scala.language.{existentials, postfixOps}
import yunsuan.vector._
import chisel3.util.experimental.decode.{QMCMinimizer, TruthTable, decoder}

class slideupVs2VdTable() extends Module {
// convert uop index of slide instruction to offset of vs2 and vd
val src = IO(Input(UInt(8.W)))
val outOffsetVs2 = IO(Output(UInt(3.W)))
val outOffsetVd = IO(Output(UInt(3.W)))
def compute_vs2_vd(lmul:Int, uopIdx:Int): (Int, Int) = {
for (i <- 0 until lmul) {
var prev = i * (i + 1) / 2
for (j <- 0 until i + 1) {
if (uopIdx == prev + j) {
return (i - j, i)
}
}
}
return (0, 0)
}
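// Example mapping derived from compute_vs2_vd above, assuming lmul = 4:
//   uopIdx    : 0  1  2  3  4  5  6  7  8  9
//   offsetVs2 : 0  1  0  2  1  0  3  2  1  0
//   offsetVd  : 0  1  1  2  2  2  3  3  3  3
// Destination register i takes i + 1 uops, walking the vs2 offset from i down to 0,
// for lmul * (lmul + 1) / 2 uops in total.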
var combLmulUopIdx : Seq[(Int, Int, Int, Int)] = Seq()
for (lmul <- 0 until 4) {
for (uopIdx <- 0 until 36) {
var offset = compute_vs2_vd(1 << lmul, uopIdx)
var offsetVs2 = offset._1
var offsetVd = offset._2
combLmulUopIdx :+= (lmul, uopIdx, offsetVs2, offsetVd)
}
}
val out = decoder(QMCMinimizer, src, TruthTable(combLmulUopIdx.map {
case (lmul, uopIdx, offsetVs2, offsetVd) =>
(BitPat((lmul << 6 | uopIdx).U(8.W)), BitPat((offsetVs2 << 3 | offsetVd).U(6.W)))
}, BitPat.N(6)))
outOffsetVs2 := out(5, 3)
outOffsetVd := out(2, 0)
}

class slidednVs2VdTable() extends Module {
// convert uop index of slide instruction to offset of vs2 and vd
val src = IO(Input(UInt(8.W)))
val outOffsetVs2 = IO(Output(UInt(3.W)))
val outOffsetVd = IO(Output(UInt(3.W)))
val outIsFirst = IO(Output(Bool()))
def compute_vs2_vd(lmul:Int, uopIdx:Int): (Int, Int, Int) = {
var uopNum = lmul * (lmul + 1) / 2
for (i <- 0 until lmul) {
var prev = lmul * i - i * (i - 1) / 2
for (j <- 0 until lmul - i) {
if (uopIdx == prev + lmul - i - j - 1) {
return (j, i, if (j == lmul - i - 1) 1 else 0)
}
}
}
return (0, 0, 0)
}
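// Example mapping derived from compute_vs2_vd above, assuming lmul = 4:
//   uopIdx    : 0  1  2  3  4  5  6  7  8  9
//   offsetVs2 : 3  2  1  0  2  1  0  1  0  0
//   offsetVd  : 0  0  0  0  1  1  1  2  2  3
//   isFirst   : 1  0  0  0  1  0  0  1  0  1
// Destination register i takes lmul - i uops; isFirst flags the first uop issued for
// each destination register.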
var combLmulUopIdx : Seq[(Int, Int, Int, Int, Int)] = Seq()
for (lmul <- 0 until 4) {
for (uopIdx <- 0 until 36) {
var offset = compute_vs2_vd(1 << lmul, uopIdx)
var offsetVs2 = offset._1
var offsetVd = offset._2
var isFirst = offset._3
combLmulUopIdx :+= (lmul, uopIdx, offsetVs2, offsetVd, isFirst)
}
}
val out = decoder(QMCMinimizer, src, TruthTable(combLmulUopIdx.map {
case (lmul, uopIdx, offsetVs2, offsetVd, isFirst) =>
(BitPat((lmul << 6 | uopIdx).U(8.W)), BitPat((isFirst << 6 | offsetVs2 << 3 | offsetVd).U(7.W)))
}, BitPat.N(7)))
outIsFirst := out(6).asBool
outOffsetVs2 := out(5, 3)
outOffsetVd := out(2, 0)
}

class Permutation extends Module {
val VLEN = 128
val xLen = 64
val LaneWidth = 64
val NLanes = VLEN / 64
val vlenb = VLEN / 8
val vlenbWidth = log2Ceil(vlenb)
val io = IO(new Bundle {
val in = Flipped(ValidIO(new VPermInput))
val out = Output(new VIFuOutput)
@@ -29,6 +101,7 @@ class Permutation extends Module {
val ma = io.in.bits.info.ma
val ta = io.in.bits.info.ta
val vlmul = io.in.bits.info.vlmul
val lmul = Mux(vlmul > 4.U, 0.U, vlmul)
val vstart = io.in.bits.info.vstart
val vl = io.in.bits.info.vl
val uopIdx = io.in.bits.info.uopIdx
@@ -37,7 +110,6 @@ class Permutation extends Module {
val vsew = srcTypeVs2(1, 0)
val vsew_plus1 = Wire(UInt(3.W))
vsew_plus1 := Cat(0.U(1.W), ~vsew) + 1.U
val signed = srcTypeVs2(3, 2) === 1.U
val widen = vdType(1, 0) === (srcTypeVs2(1, 0) + 1.U)
val vsew_bytes = 1.U << vsew
val vsew_bits = 8.U << vsew
@@ -244,62 +316,96 @@
dontTouch(compressed_res)

val base = Wire(UInt(7.W))
val vmask0 = Mux(vcompress, vs1, vmask)
val vmask1 = Mux(vcompress, vs1 >> ele_cnt, vmask >> ele_cnt)
val vmask0 = vmask
val vmask_uop = Wire(UInt(VLEN.W))
val vmask_byte_strb = Wire(Vec(vlenb, UInt(1.W)))
val vs1_bytes = VecInit(Seq.tabulate(vlenb)(i => vs1((i + 1) * 8 - 1, i * 8)))
val vs2_bytes = VecInit(Seq.tabulate(vlenb)(i => vs2((i + 1) * 8 - 1, i * 8)))
val emul = vlmul(1, 0)
val evl = Mux1H(Seq.tabulate(4)(i => (emul === i.U) -> (ele_cnt << i.U)))

val vslideupOffset = Module(new slideupVs2VdTable)
vslideupOffset.src := Cat(lmul, uopIdx)
val vslideupVs2Id = vslideupOffset.outOffsetVs2
val vslideupVd2Id = vslideupOffset.outOffsetVd

val vslidednOffset = Module(new slidednVs2VdTable)
vslidednOffset.src := Cat(lmul, uopIdx)
val vslidednVs2Id = vslidednOffset.outOffsetVs2
val vslidednVd2Id = vslidednOffset.outOffsetVd

val vrgatherVdId = Mux1H(Seq(
(vlmul === 0.U) -> 0.U,
(vlmul === 1.U) -> uopIdx(1),
(vlmul === 2.U) -> uopIdx(3, 2),
(vlmul === 3.U) -> uopIdx(5, 3),
))

val vrgatherVs2Id = Mux1H(Seq(
(vlmul === 0.U) -> 0.U,
(vlmul === 1.U) -> uopIdx(0),
(vlmul === 2.U) -> uopIdx(1, 0),
(vlmul === 3.U) -> uopIdx(2, 0),
))
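// Assuming the uop ordering implied by these bit slices, for vrgather with LMUL = 2^k the
// uops enumerate all vd x vs2 register pairs, i.e. uopIdx = vdId * LMUL + vs2Id: the low
// k bits of uopIdx select the vs2 register and the next k bits select the vd register.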

val vrgather16_sew8VdId = Mux1H(Seq(
(vlmul === 0.U) -> 0.U,
(vlmul === 1.U) -> uopIdx(2),
(vlmul === 2.U) -> uopIdx(4, 3),
))

val vrgather16_sew8Vs2Id = Mux1H(Seq(
(vlmul === 0.U) -> 0.U,
(vlmul === 1.U) -> uopIdx(1),
(vlmul === 2.U) -> uopIdx(2, 1),
))
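// With 16-bit indices and sew = 8 (vrgather16_sew8) each vd register is produced in two
// halves selected by uopIdx(0) (see the vd_reg update below), so the vd/vs2 register
// indices move up by one bit compared with the plain vrgather case.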

val vdId = Mux1H(Seq(
((vrgather && !vrgather16_sew8) || vrgather_vx) -> vrgatherVdId,
vrgather16_sew8 -> vrgather16_sew8VdId,
(vslideup) -> vslideupVd2Id,
(vslidedn) -> vslidednVd2Id,
))
val vs2Id = Mux1H(Seq(
((vrgather && !vrgather16_sew8) || vrgather_vx) -> vrgatherVs2Id,
vrgather16_sew8 -> vrgather16_sew8Vs2Id,
(vslideup) -> vslideupVs2Id,
(vslidedn) -> vslidednVs2Id,
))

dontTouch(vdId)
dontTouch(vs2Id)

val vslideup_vl = Wire(UInt(8.W))
vlRemain := vslideup_vl
when((vcompress && uopIdx(1)) ||
(vslideup && ((uopIdx === 1.U) || (uopIdx === 2.U))) ||
(vslidedn && (uopIdx === 2.U)) ||
(((vrgather && !vrgather16_sew8) || vrgather_vx) && (uopIdx >= 2.U)) ||
(vrgather16_sew8 && (uopIdx >= 4.U))
) {
vlRemain := Mux(vslideup_vl >= ele_cnt, vslideup_vl - ele_cnt, 0.U)
}.elsewhen(vslide1up) {
when(vslide1up) {
vlRemain := Mux(vl >= (uopIdx << vsew_plus1), vl - (uopIdx << vsew_plus1), 0.U)
}.elsewhen(vslide1dn) {
vlRemain := Mux(vl >= (uopIdx(5, 1) << vsew_plus1), vl - (uopIdx(5, 1) << vsew_plus1), 0.U)
}.otherwise {
vlRemain := Mux1H(Seq.tabulate(8)(i => (vdId === i.U) -> Mux(vslideup_vl >= (ele_cnt * i.U), vslideup_vl - (ele_cnt * i.U), 0.U)))
}
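// vlRemain is the element count still owed to this uop: destination register vdId skips
// ele_cnt * vdId earlier elements, while vslide1up advances by one register of elements
// per uop and vslide1dn by one register per pair of uops (hence uopIdx(5, 1)).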

vmask_uop := vmask0
when((vcompress && uopIdx(1)) ||
(vslideup && ((uopIdx === 1.U) || (uopIdx === 2.U))) ||
(vslidedn && (uopIdx === 2.U)) ||
(((vrgather && !vrgather16_sew8) || vrgather_vx) && (uopIdx >= 2.U)) ||
(vrgather16_sew8 && (uopIdx >= 4.U))
) {
vmask_uop := vmask1
}.elsewhen(vslide1up) {
when(vslide1up) {
vmask_uop := vmask >> (uopIdx << vsew_plus1)
}.elsewhen(vslide1dn) {
vmask_uop := vmask >> (uopIdx(5, 1) << vsew_plus1)
}

when((vcompress && (uopIdx === 3.U)) ||
(vslideup && (uopIdx === 1.U)) ||
(vslidedn && (uopIdx === 0.U) && (vlmul === 1.U))
) {
base := vlenb.U
}.otherwise {
base := 0.U
vmask_uop := Mux1H(Seq.tabulate(8)(i => (vdId === i.U) -> (vmask >> (ele_cnt * i.U))))
}

base := Mux1H(Seq.tabulate(8)(i => (vs2Id === i.U) -> (vlenb * i).U))
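// base is the byte offset of the current vs2 register within the LMUL register group
// (vs2Id * vlenb); the loop below expands the per-element mask into a per-byte strobe,
// one mask bit covering 1/2/4/8 bytes for sew = 0/1/2/3, with vm = 1 forcing all body
// bytes active.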

for (i <- 0 until vlenb) {
when(i.U < vlRemainBytes) {
vmask_byte_strb(i) := vmask_uop(i) | (vm & !vcompress)
vmask_byte_strb(i) := vmask_uop(i) | vm
when(vsew === 1.U(3.W)) {
vmask_byte_strb(i) := vmask_uop(i / 2) | (vm & !vcompress)
vmask_byte_strb(i) := vmask_uop(i / 2) | vm
}.elsewhen(vsew === 2.U(3.W)) {
vmask_byte_strb(i) := vmask_uop(i / 4) | (vm & !vcompress)
vmask_byte_strb(i) := vmask_uop(i / 4) | vm
}.elsewhen(vsew === 3.U(3.W)) {
vmask_byte_strb(i) := vmask_uop(i / 8) | (vm & !vcompress)
vmask_byte_strb(i) := vmask_uop(i / 8) | vm
}
}.otherwise {
vmask_byte_strb(i) := 0.U
@@ -309,9 +415,14 @@ class Permutation extends Module {
// vrgather/vrgather16
val vlmax_bytes = Wire(UInt(5.W))
val vrgather_byte_sel = Wire(Vec(vlenb, UInt(64.W)))
val first_gather = (vlmul >= 4.U) || (vlmul === 0.U) || ((vlmul === 1.U) && (Mux(vrgather16_sew8, uopIdx(1), uopIdx(0)) === 0.U))
val vs2_bytes_min = Mux((vrgather16_sew8 && uopIdx(1)) || (((vrgather && !vrgather16_sew8) || vrgather_vx) && uopIdx(0)), vlenb.U, 0.U)
val vs2_bytes_max = Mux((vrgather16_sew8 && uopIdx(1)) || (((vrgather && !vrgather16_sew8) || vrgather_vx) && uopIdx(0)), Cat(vlenb.U, 0.U), vlmax_bytes)
val first_gather = (vlmul >= 4.U) || vs2Id === 0.U
val vs2_bytes_min = Mux1H(Seq.tabulate(8)(i => (vs2Id === i.U) -> (vlenb * i).U))
val vs2_bytes_max = Mux1H(Seq(
(vs2Id === 0.U) -> vlmax_bytes,
) ++ (1 until 8).map(i => (vs2Id === i.U) -> (vlenb * (i + 1)).U))

dontTouch(vs2_bytes_min)
dontTouch(vs2_bytes_max)
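// [vs2_bytes_min, vs2_bytes_max) is the byte-index window covered by the current vs2
// register within the whole register group; a gathered byte index outside this window
// yields 0 on the first pass (first_gather, see vrgather_vd below).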
val vrgather_vd = Wire(Vec(vlenb, UInt(8.W)))

vlmax_bytes := vlenb.U
@@ -354,10 +465,14 @@ class Permutation extends Module {
vrgather_byte_sel(i) := Cat(vs1((i / 4 + 1) * 16 - 1, i / 4 * 16), 0.U(2.W)) + i.U % 4.U
}
}.elsewhen(srcTypeVs2(1, 0) === 3.U) {
when(uopIdx(1) === 0.U) {
when(uopIdx(1, 0) === 0.U) {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1) * 16 - 1, (i / 8) * 16), 0.U(3.W)) + i.U % 8.U
}.otherwise {
}.elsewhen(uopIdx(1, 0) === 1.U) {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1 + 2) * 16 - 1, (i / 8 + 2) * 16), 0.U(3.W)) + i.U % 8.U
}.elsewhen(uopIdx(1, 0) === 2.U) {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1 + 4) * 16 - 1, (i / 8 + 4) * 16), 0.U(3.W)) + i.U % 8.U
}.otherwise {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1 + 6) * 16 - 1, (i / 8 + 6) * 16), 0.U(3.W)) + i.U % 8.U
}
}
}.elsewhen(srcTypeVs1(1, 0) === 2.U) {
Expand Down Expand Up @@ -394,10 +509,14 @@ class Permutation extends Module {
vrgather_byte_sel(i) := Cat(vs1((i / 4 + 1) * 16 - 1, i / 4 * 16), 0.U(2.W)) + i.U % 4.U
}
}.elsewhen(srcTypeVs2(1, 0) === 3.U) {
when(uopIdx(1) === 0.U) {
when(uopIdx(1, 0) === 0.U) {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1) * 16 - 1, (i / 8) * 16), 0.U(3.W)) + i.U % 8.U
}.otherwise {
}.elsewhen(uopIdx(1, 0) === 1.U) {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1 + 2) * 16 - 1, (i / 8 + 2) * 16), 0.U(3.W)) + i.U % 8.U
}.elsewhen(uopIdx(1, 0) === 2.U) {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1 + 4) * 16 - 1, (i / 8 + 4) * 16), 0.U(3.W)) + i.U % 8.U
}.otherwise {
vrgather_byte_sel(i) := Cat(vs1((i / 8 + 1 + 6) * 16 - 1, (i / 8 + 6) * 16), 0.U(3.W)) + i.U % 8.U
}
}
}.elsewhen(srcTypeVs1(1, 0) === 2.U) {
@@ -413,7 +532,7 @@ class Permutation extends Module {
vrgather_vd(i) := Mux(ma, "hff".U, old_vd((i + 1) * 8 - 1, i * 8))
when(vmask_byte_strb(i).asBool) {
when((vrgather_byte_sel(i) >= vs2_bytes_min) && (vrgather_byte_sel(i) < vs2_bytes_max)) {
vrgather_vd(i) := vs2_bytes(vrgather_byte_sel(i.U) - vs2_bytes_min)
vrgather_vd(i) := vs2_bytes((vrgather_byte_sel(i) - vs2_bytes_min)(vlenbWidth - 1, 0))
}.elsewhen(first_gather) {
vrgather_vd(i) := 0.U
}.otherwise {
Expand All @@ -431,18 +550,19 @@ class Permutation extends Module {
val vslidedn_vd = Wire(Vec(vlenb, UInt(8.W)))
val vslide1dn_vd_wo_rs1 = Wire(Vec(vlenb, UInt(8.W)))
val vslide1dn_vd_rs1 = Wire(UInt(VLEN.W))
val first_slidedn = vslidedn && (uopIdx === 0.U || uopIdx === 2.U)
val first_slidedn = vslidedn && vslidednOffset.outIsFirst
val load_rs1 = (((vlmul >= 4.U) || (vlmul === 0.U)) && (uopIdx === 0.U)) ||
((vlmul === 1.U) && (uopIdx === 2.U)) ||
((vlmul === 2.U) && (uopIdx === 6.U)) ||
(uopIdx === 14.U)
val vslide1dn_vd = Mux((load_rs1 || uopIdx(0)), VecInit(Seq.tabulate(vlenb)(i => vslide1dn_vd_rs1((i + 1) * 8 - 1, i * 8))), vslide1dn_vd_wo_rs1)
dontTouch(base)

for (i <- 0 until vlenb) {
vslideup_vd(i) := Mux(ma, "hff".U, old_vd(i * 8 + 7, i * 8))
when(vmask_byte_strb(i).asBool) {
when(((base + i.U) >= slide_bytes) && ((base + i.U - slide_bytes) < vlenb.U)) {
vslideup_vd(i) := vs2_bytes(base + i.U - slide_bytes)
when(((base +& i.U) >= slide_bytes) && ((base +& i.U - slide_bytes) < vlmax_bytes)) {
vslideup_vd(i) := vs2_bytes((base +& i.U - slide_bytes)(vlenbWidth - 1, 0))
}.otherwise {
vslideup_vd(i) := old_vd(i * 8 + 7, i * 8)
}
@@ -452,8 +572,8 @@ class Permutation extends Module {
for (i <- 0 until vlenb) {
vslidedn_vd(i) := Mux(ma, "hff".U, old_vd(i * 8 + 7, i * 8))
when(vmask_byte_strb(i).asBool) {
when(((i.U + slide_bytes) >= base) && ((i.U + slide_bytes - base) < vlmax_bytes)) {
vslidedn_vd(i) := vs2_bytes(i.U + slide_bytes - base)
when(((i.U +& slide_bytes) >= base) && ((i.U +& slide_bytes - base) < vlmax_bytes)) {
vslidedn_vd(i) := vs2_bytes((i.U +& slide_bytes - base)(vlenbWidth - 1, 0))
}.elsewhen(first_slidedn) {
vslidedn_vd(i) := 0.U
}.otherwise {
@@ -466,7 +586,7 @@ class Permutation extends Module {
vslide1up_vd(i) := Mux(ma, "hff".U, old_vd(i * 8 + 7, i * 8))
when(vslide1up && (vmask_byte_strb(i) === 1.U)) {
when((i.U < vsew_bytes)) {
vslide1up_vd(i) := vs1_bytes(vlenb.U - vsew_bytes + i.U)
vslide1up_vd(i) := vs1_bytes((vlenb.U - vsew_bytes + i.U)(vlenbWidth - 1, 0))
}.otherwise {
vslide1up_vd(i) := vs2_bytes(i.U - vsew_bytes)
}
@@ -498,17 +618,12 @@ class Permutation extends Module {

val vslideup_vstart = Mux(vslideup & (slide_ele > vstart), Mux(slide_ele > VLEN.U, VLEN.U, slide_ele), vstart)
vstartRemain := vslideup_vstart
when((vcompress && (uopIdx === 3.U)) ||
((vslideup) && ((uopIdx === 1.U) || (uopIdx === 2.U))) ||
((vslidedn) && (uopIdx === 2.U)) ||
(((vrgather && !vrgather16_sew8) || vrgather_vx) && (uopIdx >= 2.U)) ||
(vrgather16_sew8 && (uopIdx >= 4.U))
) {
vstartRemain := Mux(vslideup_vstart >= ele_cnt, vslideup_vstart - ele_cnt, 0.U)
}.elsewhen(vslide1up) {
when(vslide1up) {
vstartRemain := Mux(vstart >= (uopIdx << vsew_plus1), vstart - (uopIdx << vsew_plus1), 0.U)
}.elsewhen(vslide1dn) {
vstartRemain := Mux(vstart >= (uopIdx(5, 1) << vsew_plus1), vstart - (uopIdx(5, 1) << vsew_plus1), 0.U)
}.otherwise {
vstartRemain := Mux1H(Seq.tabulate(8)(i => (vdId === i.U) -> Mux(vslideup_vstart >= (ele_cnt * i.U), vslideup_vstart - (ele_cnt * i.U), 0.U)))
}

val vd_reg = RegInit(0.U(VLEN.W))
@@ -523,9 +638,9 @@ class Permutation extends Module {
vd_reg := Cat(vslidedn_vd.reverse)
}.elsewhen(vslide1dn && fire) {
vd_reg := Cat(vslide1dn_vd.reverse)
}.elsewhen((vrgather || vrgather_vx) && !(vrgather16_sew8 && ((vlmul === 0.U) || (vlmul === 1.U))) && fire) {
}.elsewhen((vrgather || vrgather_vx) && !(vrgather16_sew8) && fire) {
vd_reg := Cat(vrgather_vd.reverse)
}.elsewhen(vrgather16_sew8 && (vlmul === 0.U) || (vlmul === 1.U) && fire) {
}.elsewhen(vrgather16_sew8 && fire) {
when(uopIdx(0)) {
vd_reg := Cat(Cat(vrgather_vd.reverse)(VLEN - 1, VLEN / 2), old_vd(VLEN / 2 - 1, 0))
}.otherwise {
@@ -565,10 +680,10 @@ class Permutation extends Module {
val tail_bytes = Mux((vlRemainBytes_reg >= vlenb.U), 0.U, vlenb.U - vlRemainBytes_reg)
val tail_bits = Cat(tail_bytes, 0.U(3.W))
val vmask_tail_bits = Wire(UInt(VLEN.W))
vmask_tail_bits := Mux(is_vmvnr_reg, vd_mask, vd_mask >> tail_bits)
vmask_tail_bits := vd_mask >> tail_bits
val tail_old_vd = old_vd_reg & (~vmask_tail_bits)
val tail_ones_vd = ~vmask_tail_bits
val tail_vd = Mux(is_vmvnr_reg, 0.U, Mux(ta_reg, tail_ones_vd, tail_old_vd))
val tail_vd = Mux(ta_reg, tail_ones_vd, tail_old_vd)
val perm_tail_mask_vd = Wire(UInt(VLEN.W))

val vstart_bytes = Mux(vstartRemainBytes_reg >= vlenb.U, vlenb.U, vstartRemainBytes_reg)