diff --git a/README.md b/README.md
index ca17c6f7a..d1fa9bdf8 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,9 @@ We replace the full complex hand-crafted object detection pipeline with a Transf
 
 ![DETR](.github/DETR.png)
 
+**New**
+* [Gradio App](https://gradio.app/g/AK391/detr)
+
 **What it is**. Unlike traditional computer vision techniques, DETR approaches object detection as a direct set prediction problem. It consists of a set-based global loss, which forces unique predictions via bipartite matching, and a Transformer encoder-decoder architecture. Given a fixed small set of learned object queries, DETR reasons about the relations of the objects and the global image context to directly output the final set of predictions in parallel. Due to this parallel nature, DETR is very fast and efficient.
 
diff --git a/demo.py b/demo.py
new file mode 100644
index 000000000..79c851d66
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,111 @@
+from PIL import Image
+import matplotlib.pyplot as plt
+import torch
+import torchvision.transforms as T
+import gradio as gr
+import io
+
+torch.set_grad_enabled(False)  # inference only, no gradients needed
+
+model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
+model.eval()  # make sure the model is in inference mode
+
+
+# COCO classes
+CLASSES = [
+    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
+    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
+    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
+    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
+    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
+    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
+    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
+    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
+    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
+    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
+    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
+    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
+    'toothbrush'
+]
+
+# colors for visualization
+COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
+          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
+
+# standard PyTorch mean-std input image normalization
+transform = T.Compose([
+    T.Resize(800),
+    T.ToTensor(),
+    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+
+# for output bounding box post-processing
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=1)
+
+def rescale_bboxes(out_bbox, size):
+    img_w, img_h = size
+    b = box_cxcywh_to_xyxy(out_bbox)
+    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
+    return b
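+
+# Note: DETR predicts each box as normalized (center_x, center_y, width, height)
+# in [0, 1]; the two helpers above convert boxes to absolute (x0, y0, x1, y1)
+# pixel coordinates for plotting.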
+
+def fig2img(fig):
+    """Convert a Matplotlib figure to a PIL Image and return it"""
+    buf = io.BytesIO()
+    fig.savefig(buf)
+    buf.seek(0)
+    return Image.open(buf)
+
+
+def plot_results(pil_img, prob, boxes):
+    fig = plt.figure(figsize=(16, 10))
+    plt.imshow(pil_img)
+    ax = plt.gca()
+    colors = COLORS * 100
+    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
+        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
+                                   fill=False, color=c, linewidth=3))
+        cl = p.argmax()
+        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
+        ax.text(xmin, ymin, text, fontsize=15,
+                bbox=dict(facecolor='yellow', alpha=0.5))
+    plt.axis('off')
+    img = fig2img(fig)
+    plt.close(fig)  # free the figure so repeated calls don't leak memory
+    return img
+
+
+def detr(im):
+    # mean-std normalize the input image (batch-size: 1)
+    img = transform(im).unsqueeze(0)
+
+    # propagate through the model
+    outputs = model(img)
+
+    # keep only predictions with confidence above 0.9
+    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
+    keep = probas.max(-1).values > 0.9
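+    # (the last logit column is DETR's special "no object" class: it takes
+    # part in the softmax normalization but is dropped from the scores kept here)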
" + +gr.Interface(detr, inputs, outputs, title=title, description=description, article=article, examples=examples).launch() \ No newline at end of file diff --git a/horses.jpg b/horses.jpg new file mode 100644 index 000000000..31e0bc56c Binary files /dev/null and b/horses.jpg differ diff --git a/pandas.jpg b/pandas.jpg new file mode 100644 index 000000000..4f56cdb76 Binary files /dev/null and b/pandas.jpg differ diff --git a/requirements.txt b/requirements.txt index bb8f7823b..63124452d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,11 @@ cython -git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI&egg=pycocotools submitit torch>=1.5.0 torchvision>=0.6.0 -git+https://github.com/cocodataset/panopticapi.git#egg=panopticapi scipy onnx onnxruntime +Pillow +matplotlib +gradio +numpy