feat: add ffs inference code allowing arbitrary input size by ding3820 · Pull Request #5 · aim-uofa/SINE · GitHub

feat: add ffs inference code allowing arbitrary input size #5

Open · wants to merge 1 commit into main

260 changes: 260 additions & 0 deletions demo_fss.py
@@ -0,0 +1,260 @@
import argparse
import os
import torch
from PIL import Image
import numpy as np
from torchvision import transforms
from detectron2.structures import Instances, BitMasks
import torch.nn.functional as F
from torchvision.transforms.functional import pil_to_tensor
from tqdm import tqdm

# Import build_model from the inference_fss module
from inference_fss.model.model import build_model


def pad_img(x, pad_size_h, pad_size_w):
"""Pad an image or mask tensor to the specified height and width."""
assert isinstance(x, torch.Tensor)
h, w = x.shape[-2:]
padh = pad_size_h - h
padw = pad_size_w - w
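    # Note: F.pad's spec here is (width_left, width_right, height_top, height_bottom),
    # so only the right and bottom edges are padded and the content stays anchored at
    # the top-left corner. Negative pads would crop, so inputs are assumed to be no
    # larger than the target size.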
return F.pad(x, (0, padw, 0, padh))


def preprocess_image_and_mask(
img_path, mask_path, args, device, class_id=0, instance_id=0
):
"""Load and preprocess an image and optionally its mask, returning a dict for model input."""
# Define image transformation
encoder_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]
)

# Load image
img = Image.open(img_path).convert("RGB")
img_tensor = encoder_transform(img).to(device)
    original_shape = img_tensor.shape[-2:]  # (H, W) of the unpadded input
img_tensor = pad_img(img_tensor, args.pad_size_h, args.pad_size_w)

# Base dictionary
data_dict = {
"image": img_tensor,
"height": original_shape[0],
"width": original_shape[1],
"original_img": img, # Store original PIL image for overlay
}

# Process mask if provided
if mask_path:
mask = pil_to_tensor(Image.open(mask_path).convert("L")).long().to(device)
mask = pad_img(mask, args.pad_size_h, args.pad_size_w) # (1, H, W)
else:
mask = torch.zeros_like(img_tensor[:1]) # (1, H, W)

# Create Instances object
instances = Instances(original_shape)
instances.gt_classes = torch.tensor([class_id], device=device)
mask = BitMasks(mask)
instances.gt_masks = mask.tensor
instances.gt_boxes = mask.get_bounding_boxes()
instances.ins_ids = torch.tensor([instance_id], device=device)
data_dict["instances"] = instances

return data_dict


def load_support_data(support_img_paths, support_mask_paths, args, device):
"""Load and preprocess all support images and masks."""
return [
preprocess_image_and_mask(
img_path, mask_path, args, device, class_id=0, instance_id=0
)
for img_path, mask_path in zip(support_img_paths, support_mask_paths)
]


def parse_data(support_imgs, support_masks, query_imgs):
"""Parse and examine file paths for support and query sets."""

for paths, name in [
(support_imgs, "support images"),
        (support_masks, "support masks"),
(query_imgs, "query images"),
]:
assert all(os.path.exists(p) for p in paths), f"Some {name} are missing"

return support_imgs, support_masks, query_imgs


def overlay_mask_on_image(original_img, pred_mask, color=(173, 216, 230), alpha=0.5):
"""Overlay a binary mask on the original image with specified color and transparency."""
# Convert to PIL images
image = original_img.convert("RGBA")
mask = pred_mask.convert("L")

# Create a color overlay with the given color and transparency
overlay = Image.new("RGBA", image.size, color + (0,))
overlay.putalpha(mask.point(lambda p: p * alpha))

# Composite the overlay onto the original image
overlayed = Image.alpha_composite(image, overlay)

return overlayed


def main():
# Parse command-line arguments
parser = argparse.ArgumentParser(
description="Batch Inference with Overlay for Few-Shot Segmentation"
)
parser.add_argument(
"--support_img_paths",
type=str,
nargs="+",
required=True,
help="Paths to support images",
)
parser.add_argument(
"--support_mask_paths",
type=str,
nargs="+",
required=True,
help="Paths to support masks corresponding to support images",
)
parser.add_argument(
"--query_img_path",
type=str,
nargs="+",
required=True,
help="Path to the query image to segment",
)
parser.add_argument(
"--output_dir",
type=str,
default="outputs",
help="Directory to save overlaid images",
)
parser.add_argument(
"--model_weights",
type=str,
default="checkpoints/pytorch_model.bin",
help="Path to the pre-trained model weights",
)
parser.add_argument(
"--pad_size_h",
type=int,
default=896,
help="Height to pad images to",
)
parser.add_argument(
"--pad_size_w",
type=int,
default=896,
help="Width to pad images to",
)
parser.add_argument(
"--device", default="cuda", help="Device to run inference on (cuda or cpu)"
)
parser.add_argument(
"--score_threshold",
type=float,
default=0.5,
help="Threshold for binary segmentation",
)
parser.add_argument(
"--alpha",
type=float,
default=0.5,
help="Transparency of the mask overlay (0.0 to 1.0)",
)
# Model-specific arguments
parser.add_argument("--feat_chans", type=int, default=256)
parser.add_argument("--image_enc_use_fc", action="store_true")
parser.add_argument("--transformer_depth", type=int, default=6)
parser.add_argument("--transformer_nheads", type=int, default=8)
parser.add_argument("--transformer_mlp_dim", type=int, default=2048)
parser.add_argument("--transformer_mask_dim", type=int, default=256)
parser.add_argument("--transformer_fusion_layer_depth", type=int, default=1)
parser.add_argument("--transformer_num_queries", type=int, default=200)
parser.add_argument("--transformer_pre_norm", action="store_true", default=True)
parser.add_argument("--pt_model", type=str, default="dinov2")
parser.add_argument("--dinov2-size", type=str, default="vit_large")
parser.add_argument(
"--dinov2-weights", type=str, default="checkpoints/dinov2_vitl14_pretrain.pth"
)

args = parser.parse_args()

# Validate input
assert len(args.support_img_paths) == len(args.support_mask_paths), (
"Number of support images must match number of support masks"
)

# Set up device
device = torch.device(args.device if torch.cuda.is_available() else "cpu")

# Load the model
model = build_model(args)
state_dict = torch.load(args.model_weights, map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model.to(device)
model.eval()

# Load support and query data
support_imgs, support_masks, query_imgs = parse_data(
args.support_img_paths,
args.support_mask_paths,
args.query_img_path,
)
ref_list = load_support_data(support_imgs, support_masks, args, device)

# Create output directory
os.makedirs(args.output_dir, exist_ok=True)

# Process query images
print("Running inference and generating overlaid images...")
for query_img_path in tqdm(query_imgs, total=len(query_imgs)):
# Preprocess query image and mask
tar_dict = preprocess_image_and_mask(
query_img_path, None, args, device, class_id=0, instance_id=1
)

# Prepare input data
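        # Each list element pairs one support (reference) example with the query;
        # the query dict is attached only to the first entry, so a k-shot episode
        # is passed as k reference dicts sharing a single target dict.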
data = [
{"ref_dict": ref, "tar_dict": tar_dict if i == 0 else None}
for i, ref in enumerate(ref_list)
]

# Perform inference
with torch.no_grad():
output = model(data)
pred = output["sem_seg"].squeeze() # Shape: (H, W)
pred = pred > args.score_threshold # Binary mask
pred = pred.float().cpu().numpy()
pred_mask = (pred * 255).astype(np.uint8)

# Save the predicted mask
pred_mask = Image.fromarray(pred_mask)
pred_mask.save(
os.path.join(
args.output_dir, f"{os.path.basename(query_img_path)[:-4]}_mask.png"
)
)

# Save the overlaid image
        overlaid_img = overlay_mask_on_image(
            tar_dict["original_img"], pred_mask, alpha=args.alpha
        )
overlaid_img.save(
os.path.join(
args.output_dir, f"{os.path.basename(query_img_path)[:-4]}_overlay.png"
)
)

print(f"Overlaid images saved to {args.output_dir}")


if __name__ == "__main__":
main()
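
For reference, a minimal way to drive the script end-to-end is sketched below. The file names are placeholders (the repository does not ship these example images), and the checkpoint paths simply repeat the defaults wired into the argument parser above; this is a usage sketch, not part of the PR.

import sys
from demo_fss import main  # assumes the file above is saved as demo_fss.py

# Placeholder paths; substitute your own support/query images and masks.
sys.argv = [
    "demo_fss.py",
    "--support_img_paths", "examples/support_1.jpg",
    "--support_mask_paths", "examples/support_1_mask.png",
    "--query_img_path", "examples/query_1.jpg",
    "--output_dir", "outputs",
    "--model_weights", "checkpoints/pytorch_model.bin",
    "--dinov2-weights", "checkpoints/dinov2_vitl14_pretrain.pth",
]
main()
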
37 changes: 24 additions & 13 deletions inference_fss/model/transformer_decoder/mformer.py
@@ -499,25 +499,36 @@ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size)
 
         return outputs_class, outputs_mask, attn_mask, mask_embed
 
-    def apply_gaussian_kernel(self, corr, spatial_side, sigma=10):
-        bsz, side1, side2 = corr.size()
+    def apply_gaussian_kernel(self, corr, spatial_height, spatial_width, sigma=10):
+        bsz, nm, _ = corr.size()
 
+        # Get max correlation index for each query
         center = corr.max(dim=2)[1]
-        center_y = center // spatial_side
-        center_x = center % spatial_side
+
+        # Compute y, x coordinates from the flattened index
+        center_y = center // spatial_width  # Row index (height)
+        center_x = center % spatial_width  # Column index (width)
 
-        x = torch.arange(0, spatial_side).float().to(corr.device)
-        y = torch.arange(0, spatial_side).float().to(corr.device)
+        # Create coordinate grid
+        x = torch.arange(0, spatial_width).float().to(corr.device)
+        y = torch.arange(0, spatial_height).float().to(corr.device)
 
-        y = y.view(1, 1, spatial_side).repeat(bsz, center_y.size(1), 1) - center_y.unsqueeze(2)
-        x = x.view(1, 1, spatial_side).repeat(bsz, center_x.size(1), 1) - center_x.unsqueeze(2)
+        # Compute y-distance and x-distance from the center
+        y = y.view(1, 1, spatial_height).repeat(bsz, nm, 1) - center_y.unsqueeze(2)
+        x = x.view(1, 1, spatial_width).repeat(bsz, nm, 1) - center_x.unsqueeze(2)
 
-        y = y.unsqueeze(3).repeat(1, 1, 1, spatial_side)
-        x = x.unsqueeze(2).repeat(1, 1, spatial_side, 1)
+        # Expand dimensions for proper broadcasting
+        y = y.unsqueeze(3).repeat(1, 1, 1, spatial_width)  # (B, nm, H, W)
+        x = x.unsqueeze(2).repeat(1, 1, spatial_height, 1)  # (B, nm, H, W)
 
+        # Compute Gaussian kernel with different H and W
         gauss_kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
-        filtered_corr = gauss_kernel * corr.view(bsz, -1, spatial_side, spatial_side)
-        filtered_corr = filtered_corr.view(bsz, side1, side2)
+
+        # Reshape correlation map to match (B, nm, H, W)
+        filtered_corr = gauss_kernel * corr.view(bsz, nm, spatial_height, spatial_width)
+
+        # Restore shape back to (B, nm, H*W)
+        filtered_corr = filtered_corr.view(bsz, nm, spatial_height * spatial_width)
 
         return filtered_corr

@@ -584,7 +595,7 @@ def forward_per_image(
         id_query_feat_norm = F.normalize(id_query_feat, dim=-1, p=2)
         corr_matrix = torch.einsum('nac,nbc->nab', id_query_feat_norm, image_feat_norm) # 1, nm, HW
 
-        id_corr_matrix = self.apply_gaussian_kernel(corr_matrix, h)
+        id_corr_matrix = self.apply_gaussian_kernel(corr_matrix, h, w)
         id_dist = torch.softmax(id_corr_matrix * self.temp, dim=-1)
         id_embed = torch.einsum('nab,nbc->nac', id_dist, image_pe)
         output_id = id_query_feat # bs, nm, c
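
The shape handling above is the core of the arbitrary-input-size support: the correlation map is now unflattened with separate height and width instead of a single square side. A self-contained sketch of the same computation on a toy non-square grid (names and sizes here are illustrative, not taken from the repository) is:

import torch

# Toy sizes: batch 1, 2 mask queries, a 3 x 5 (non-square) feature grid.
bsz, nm, H, W = 1, 2, 3, 5
corr = torch.rand(bsz, nm, H * W)

center = corr.max(dim=2)[1]          # flat index of the peak correlation per query
center_y, center_x = center // W, center % W

y = torch.arange(H).float().view(1, 1, H).repeat(bsz, nm, 1) - center_y.unsqueeze(2)
x = torch.arange(W).float().view(1, 1, W).repeat(bsz, nm, 1) - center_x.unsqueeze(2)
y = y.unsqueeze(3).repeat(1, 1, 1, W)    # (bsz, nm, H, W) row offsets
x = x.unsqueeze(2).repeat(1, 1, H, 1)    # (bsz, nm, H, W) column offsets

sigma = 10
gauss = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
filtered = (gauss * corr.view(bsz, nm, H, W)).view(bsz, nm, H * W)
print(filtered.shape)  # torch.Size([1, 2, 15])
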
2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,7 +2,7 @@ torch==2.0.1
 torchvision==0.15.2
 xformers==0.0.21
 opencv-python==4.8.0.76
-timm==0.9.17
+timm==1.0.3
 omegaconf==2.3.0
 numpy==1.26.1
 tqdm==4.66.1