Skip to content

Commit

Permalink
Allow to choose to return original image along BiRefNet results or not
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitribarbot committed Sep 1, 2024
1 parent d870198 commit df03b4b
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions scripts/postprocessing_rembg.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def ui(self):
with FormRow():
model = gr.Dropdown(label="Remove background model", choices=models, value="None", info="Choose a BiRefNet model. Each model gives a different result.")
resolution = gr.Textbox(label="Resolution", value="", placeholder="1024x1024", info="If left empty, it will take image size rounded to the nearest multiple of 32.")

with FormRow():
return_original = gr.Checkbox(label="Return original image", value=False)
return_foreground = gr.Checkbox(label="Return foreground", value=False)
return_edge_mask = gr.Checkbox(label="Return edge mask", value=False)

Expand All @@ -55,12 +58,13 @@ def ui(self):
"enable": enable,
"model": model,
"resolution": resolution,
"return_original": return_original,
"return_foreground": return_foreground,
"return_edge_mask": return_edge_mask,
"edge_mask_width": edge_mask_width,
}

def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, model, resolution, return_foreground, return_edge_mask, edge_mask_width):
def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, model, resolution, return_original, return_foreground, return_edge_mask, edge_mask_width):
if not enable:
return

Expand All @@ -77,10 +81,16 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, model,
edge_mask_width=edge_mask_width
)

if output_image:
pp.extra_images.append(output_image)
if mask:
pp.extra_images.append(mask)
if return_original:
if output_image:
pp.extra_images.append(output_image)
if mask:
pp.extra_images.append(mask)
else:
pp.image = output_image or mask
if output_image:
pp.extra_images.append(mask)

if edge_mask:
pp.extra_images.append(edge_mask)

Expand Down

0 comments on commit df03b4b

Please sign in to comment.