Skip to content

Commit

Permalink
fix: rtmdet infer
Browse files Browse the repository at this point in the history
  • Loading branch information
nullptr committed Dec 27, 2024
1 parent 2d49d9b commit 6f9db5b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
3 changes: 1 addition & 2 deletions sscma/deploy/backend/tflite_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def infer(self, input_data):
data = (data / scale + zero_point).astype(
input["dtype"]
) # de-scale

self.interpreter.set_tensor(input["index"], data)
self.interpreter.invoke()
y = []
Expand All @@ -49,7 +48,7 @@ def infer(self, input_data):
scale, zero_point = output["quantization"]
x = (x.astype(np.float32) - zero_point) * scale # re-scale
# numpy x convert NHWC to NCWH
y.append(np.transpose(x, [0, 3, 1, 2]))
y.append(np.transpose(x, [0, 3, 1, 2]) if len(x.shape) == 4 else x)

results.append(y)
return results
Expand Down
13 changes: 9 additions & 4 deletions sscma/deploy/models/rtmdet_infer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import math
import warnings
from typing import Dict, List, Tuple, Union

Expand Down Expand Up @@ -104,13 +105,17 @@ def _predict(
for dt in data_tmp:
tmp = [None for _ in range(6)]
for d in dt:
if d.shape[2:] in featmap_size:
if d.shape[1] == 4:
tmp[3 + featmap_size.index(d.shape[2:])] = d
fs = int(math.sqrt(d.shape[1]))
ts = (fs, fs)
if ts in featmap_size:
if d.shape[2] == 4:
tmp[3 + featmap_size.index(ts)] = d
else:
tmp[featmap_size.index(d.shape[2:])] = d
tmp[featmap_size.index(ts)] = d

data.append(tmp)


for result, data_sample in zip(data, batch_data_samples):
# check item in result is tensor or numpy

Expand Down

0 comments on commit 6f9db5b

Please sign in to comment.