Classification code with modification of train.py and datasets/brains18.py · Issue #79 · Tencent/MedicalNet · GitHub
Classification code with modification of train.py and datasets/brains18.py #79
Open
@loopnownow

I'd like to share classification code built by modifying train.py and datasets/brains18.py.
The code is based on JasperHG90's comment in #58.
To adapt the code for classification, three main problems need to be resolved:
First, the format of the class labels.
Second, how to feed the class labels into the code.
Third, how to define an appropriate loss.

For the first one, I simply appended the class label to each line of the training data lists under data/train, after the path of the mask. For example:
MRBrainS18/images/75.nii.gz MRBrainS18/labels/75.nii.gz 0
MRBrainS18/images/14.nii.gz MRBrainS18/labels/14.nii.gz 0
MRBrainS18/images/148.nii.gz MRBrainS18/labels/148.nii.gz 0
MRBrainS18/images/4.nii.gz MRBrainS18/labels/4.nii.gz 0
MRBrainS18/images/5.nii.gz MRBrainS18/labels/5.nii.gz 0
MRBrainS18/images/7.nii.gz MRBrainS18/labels/7.nii.gz 1
MRBrainS18/images/71.nii.gz MRBrainS18/labels/71.nii.gz 1
MRBrainS18/images/72.nii.gz MRBrainS18/labels/72.nii.gz 1
MRBrainS18/images/73.nii.gz MRBrainS18/labels/73.nii.gz 1
MRBrainS18/images/74.nii.gz MRBrainS18/labels/74.nii.gz 1
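Each list line therefore has three whitespace-separated fields: image path, mask path, class label. A minimal parser sketch for that format (the `ListEntry` and `parse_list` names are hypothetical helpers, not part of MedicalNet) makes the layout explicit and rejects lines that are missing the label:

```python
from typing import List, NamedTuple


class ListEntry(NamedTuple):
    """One line of a data/train list: image path, mask path, class label."""
    image_path: str
    mask_path: str
    label: int


def parse_list_line(line: str) -> ListEntry:
    # Fields are whitespace-separated: "<image> <mask> <label>".
    parts = line.split()
    if len(parts) != 3:
        raise ValueError(f"expected 3 fields, got {len(parts)}: {line!r}")
    image_path, mask_path, label = parts
    return ListEntry(image_path, mask_path, int(label))


def parse_list(text: str) -> List[ListEntry]:
    # Skip blank lines so a trailing newline does not produce an empty entry.
    return [parse_list_line(l) for l in text.splitlines() if l.strip()]


sample = (
    "MRBrainS18/images/75.nii.gz MRBrainS18/labels/75.nii.gz 0\n"
    "MRBrainS18/images/7.nii.gz MRBrainS18/labels/7.nii.gz 1\n"
)
for entry in parse_list(sample):
    print(entry.image_path, entry.label)
```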

For the second one, I modified datasets/brains18.py. Modified lines are marked with #####, and "line N" gives the location in the file.
The modifications were:
import torch #####
line 43, label_name = os.path.join(self.root_dir, ith_info[1])
class_array = np.zeros((1)) #####
class_array[0] = float(ith_info[2]) ##### class label from the third column of the list line
class_array = torch.tensor(class_array, dtype=torch.float32) #####

line 59, return img_array, mask_array, class_array #####
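To see why the label is wrapped in a length-1 float32 tensor, here is a toy stand-in for the training branch of the dataset. `ToyClassificationDataset` is hypothetical, and random arrays replace the NIfTI loading that brains18.py does. After `DataLoader` collation the labels have shape [batch, 1], which matches the [batch, 1] sigmoid output of the classifier and the float targets `nn.BCELoss` expects.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ToyClassificationDataset(Dataset):
    """Hypothetical stand-in for the modified BrainS18 training branch."""

    def __init__(self, lines, shape=(1, 14, 112, 112)):
        # Each line: "<image> <mask> <label>", as in the data/train lists.
        self.lines = lines
        self.shape = shape

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        ith_info = self.lines[idx].split(" ")
        # Random/zero arrays stand in for the loaded image and mask volumes.
        img_array = np.random.rand(*self.shape).astype("float32")
        mask_array = np.zeros(self.shape, dtype="float32")
        # Same idea as the modification above: a length-1 float32 label.
        class_array = np.zeros(1)
        class_array[0] = float(ith_info[2])
        class_array = torch.tensor(class_array, dtype=torch.float32)
        return img_array, mask_array, class_array


lines = [
    "MRBrainS18/images/75.nii.gz MRBrainS18/labels/75.nii.gz 0",
    "MRBrainS18/images/7.nii.gz MRBrainS18/labels/7.nii.gz 1",
]
loader = DataLoader(ToyClassificationDataset(lines), batch_size=2)
volumes, label_masks, class_array = next(iter(loader))
print(volumes.shape, class_array.shape)
# torch.Size([2, 1, 14, 112, 112]) torch.Size([2, 1])
```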

For the third one, I modified train.py. Modified lines are marked with #####; where an existing line was replaced, the original code follows the #####.
The modifications were:
from models import resnet #####
# The code below is from JasperHG90's comment in #58; modifications are marked with ##### (note: I used resnet50 rather than resnet34)
class MedicalNet(nn.Module):

  def __init__(self, path_to_weights, device):
    super(MedicalNet, self).__init__()
    self.model = resnet.resnet50(sample_input_D=14, sample_input_H=112, sample_input_W=112, num_seg_classes=2)
    self.model.conv_seg = nn.Sequential(
        nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
        nn.Flatten(start_dim=1),
        nn.Dropout(0.1)
    )
    net_dict = self.model.state_dict()
    pretrained_weights = torch.load(path_to_weights, map_location=torch.device(device))
    pretrain_dict = {
        k.replace("module.", ""): v for k, v in pretrained_weights['state_dict'].items() if k.replace("module.", "") in net_dict.keys()
      }
    net_dict.update(pretrain_dict)
    self.model.load_state_dict(net_dict)
    self.fc = nn.Linear(2048, 1)

  def forward(self, x):
    features = self.model(x)
    return torch.sigmoid_(self.fc(features)) #####
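Running the class above needs the real resnet50 code and weights, so the sketch below swaps in a hypothetical two-layer 3D CNN (`TinyBackbone`) that ends in 2048 channels, like ResNet-50's last stage, and keeps the same head: the `conv_seg` slot becomes global max pooling, flatten, and dropout, followed by `fc = nn.Linear(2048, 1)` and a sigmoid. It shows that the output has shape [batch, 1] and that parameter names start with `model.` or `fc.`, never `conv_seg`.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for resnet.resnet50: ends with 2048 channels."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1)
        self.layer4 = nn.Conv3d(16, 2048, kernel_size=3, stride=2, padding=1)
        self.conv_seg = nn.Identity()  # replaced below, as in MedicalNet

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.layer4(x))
        return self.conv_seg(x)


class ToyMedicalNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = TinyBackbone()
        # Same head replacement as above: pool to 1x1x1, flatten, dropout.
        self.model.conv_seg = nn.Sequential(
            nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
            nn.Flatten(start_dim=1),
            nn.Dropout(0.1),
        )
        self.fc = nn.Linear(2048, 1)

    def forward(self, x):
        features = self.model(x)                  # [batch, 2048]
        return torch.sigmoid(self.fc(features))   # [batch, 1], values in [0, 1]


net = ToyMedicalNet().eval()
out = net(torch.randn(2, 1, 14, 32, 32))
print(out.shape)  # torch.Size([2, 1])
print(sorted({name.split(".")[0] for name, _ in net.named_parameters()}))  # ['fc', 'model']
```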

line 24, loss_seg = nn.BCELoss() ##### loss_seg = nn.CrossEntropyLoss(ignore_index=-1)
line 43, volumes, label_masks, class_array = batch_data ##### volumes, label_masks = batch_data
line 47, class_array = class_array.cuda() #####
line 66, loss_value_seg = loss_seg(out_masks, class_array) ##### loss_value_seg = loss_seg(out_masks, new_label_masks)
line 67, loss = loss_value_seg
line 68, loss.requires_grad_(True) #####
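Because `forward` already applies a sigmoid, `nn.BCELoss` is the matching loss, and the targets must be floats with the same [batch, 1] shape as the output (hence the float32 `class_array`). A numerically more stable alternative, not what this post uses, is to return the raw `fc` output and use `nn.BCEWithLogitsLoss`. The sketch checks on random data that both give the same value as the binary cross-entropy formula.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 1)                                  # raw fc outputs, [batch, 1]
class_array = torch.tensor([[0.0], [1.0], [1.0], [0.0]])    # float targets, [batch, 1]

# As in the post: sigmoid inside forward(), then BCELoss.
loss_bce = nn.BCELoss()(torch.sigmoid(logits), class_array)

# Alternative: keep logits and let the loss apply the sigmoid internally.
loss_logits = nn.BCEWithLogitsLoss()(logits, class_array)

print(loss_bce.item(), loss_logits.item())
```

If you switch to `BCEWithLogitsLoss`, drop the sigmoid from `forward` and apply `torch.sigmoid` only when turning outputs into probabilities at inference time.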

line 118, model = MedicalNet(path_to_weights="pretrain/resnet_50.pth", device='cuda') #####model, parameters = generate_model(sets)
line 119, model = MedicalNet(path_to_weights="pretrain/resnet_50.pth", device='cuda') #####print (model)
line 120, model.cuda() ##### # optimizer
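The weight loading inside `MedicalNet.__init__` strips the `module.` prefix (typically present when a checkpoint was saved from a model wrapped in `nn.DataParallel`) and keeps only keys the new backbone has, so checkpoint entries without a match are skipped. Here is the same dict comprehension on plain dicts; `filter_pretrained` is a hypothetical helper, the key names are illustrative, and strings stand in for weight tensors:

```python
def filter_pretrained(pretrained_state, net_state):
    """Keep checkpoint entries whose key (minus "module.") exists in net_state."""
    pretrain_dict = {
        k.replace("module.", ""): v
        for k, v in pretrained_state.items()
        if k.replace("module.", "") in net_state
    }
    merged = dict(net_state)       # start from the new network's own values
    merged.update(pretrain_dict)   # overwrite matching keys with pretrained values
    return merged


# Illustrative keys; strings stand in for weight tensors.
pretrained = {
    "module.conv1.weight": "pretrained_conv1",
    "module.layer4.0.conv1.weight": "pretrained_layer4",
    "module.conv_seg.0.weight": "pretrained_seg_head",
}
net = {
    "conv1.weight": "random_conv1",
    "layer4.0.conv1.weight": "random_layer4",
}
print(filter_pretrained(pretrained, net))
# {'conv1.weight': 'pretrained_conv1', 'layer4.0.conv1.weight': 'pretrained_layer4'}
```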

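for param_name, param in model.named_parameters():
    # Parameter names are "model.…" (backbone) and "fc.…" (new head); the replaced
    # conv_seg has no parameters, so the prefix "conv_seg" would freeze everything.
    # With fc trainable, loss.requires_grad_(True) at line 68 is no longer needed.
    if param_name.startswith("fc"):
        param.requires_grad = True
    else:
        param.requires_grad = False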
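The loop selects trainable parameters by name prefix, so it helps to check which names actually exist. In the wrapper, backbone parameters are named `model.…` and the new linear layer `fc.…`; the replacement `conv_seg` (pooling, flatten, dropout) has no parameters at all. The sketch below uses a small hypothetical stand-in with the same attribute layout (`StandIn`, `freeze_all_but` are illustrative names) to compare the prefixes `conv_seg` and `fc`.

```python
import torch.nn as nn


class StandIn(nn.Module):
    """Hypothetical stand-in with the same attribute layout as MedicalNet."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential()
        self.model.add_module("conv1", nn.Conv3d(1, 8, kernel_size=3))
        self.model.add_module("conv_seg", nn.Sequential(
            nn.AdaptiveMaxPool3d((1, 1, 1)), nn.Flatten(start_dim=1), nn.Dropout(0.1)))
        self.fc = nn.Linear(8, 1)


def freeze_all_but(model, prefix):
    """Unfreeze parameters whose name starts with prefix; freeze the rest."""
    trainable = []
    for param_name, param in model.named_parameters():
        param.requires_grad = param_name.startswith(prefix)
        if param.requires_grad:
            trainable.append(param_name)
    return trainable


net = StandIn()
print(freeze_all_but(net, "conv_seg"))  # []  (nothing would be trained)
print(freeze_all_but(net, "fc"))        # ['fc.weight', 'fc.bias']
```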

line 128, optimizer = torch.optim.SGD(model.parameters(), lr=1e-5,momentum=0.9, weight_decay=1e-3) #####optimizer = torch.optim.SGD(params, momentum=0.9, weight_decay=1e-3)
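Putting the pieces together, here is one training step in the shape the modified train.py loop takes: unpack `volumes, label_masks, class_array`, run the model, apply `nn.BCELoss` to the [batch, 1] output, and step SGD. It is a sketch under stated assumptions: a hypothetical tiny network (`TinyClassifier`) on CPU instead of MedicalNet on CUDA, everything but `fc` frozen, only trainable parameters passed to the optimizer, and `lr=1e-2` chosen for the toy rather than the post's `1e-5`. There is no mask-resizing step, since the output has no spatial dimensions to resize labels to.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyClassifier(nn.Module):
    """Hypothetical stand-in: 'model' backbone + 'fc' head + sigmoid."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv3d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveMaxPool3d((1, 1, 1)),
            nn.Flatten(start_dim=1),
        )
        self.fc = nn.Linear(8, 1)

    def forward(self, x):
        return torch.sigmoid(self.fc(self.model(x)))


net = TinyClassifier()
for name, param in net.named_parameters():
    param.requires_grad = name.startswith("fc")    # train only the new head

optimizer = torch.optim.SGD(
    [p for p in net.parameters() if p.requires_grad],
    lr=1e-2, momentum=0.9, weight_decay=1e-3)
loss_seg = nn.BCELoss()

# One fake batch shaped like the modified dataset output.
volumes = torch.randn(4, 1, 8, 16, 16)
label_masks = torch.zeros(4, 1, 8, 16, 16)        # unused for classification
class_array = torch.tensor([[0.0], [1.0], [1.0], [0.0]])

backbone_before = net.model[0].weight.clone()
fc_before = net.fc.weight.clone()

optimizer.zero_grad()
out_masks = net(volumes)                           # [4, 1]
loss = loss_seg(out_masks, class_array)
loss.backward()                                    # works: fc params require grad
optimizer.step()

print(loss.item(), loss.requires_grad)
```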
