8000 RSCFed/datasets.py at main · xmed-lab/RSCFed · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Latest commit

 

History

History
237 lines (180 loc) · 7.27 KB

datasets.py

File metadata and controls

237 lines (180 loc) · 7.27 KB
import torch.utils.data as data
from PIL import Image
import numpy as np
import torchvision
from torchvision.datasets import MNIST, EMNIST, CIFAR10, CIFAR100, SVHN, FashionMNIST, ImageFolder, DatasetFolder, utils
import os
import os.path
import logging
# Configure logging for the whole process at import time: basicConfig()
# attaches a default handler to the root logger if none is installed yet.
logging.basicConfig()
# getLogger() without a name returns the root logger, so the INFO level set
# below applies process-wide, not just to this module.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Image file suffixes. Not referenced by any definition visible in this
# section of the file.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
def mkdirs(dirpath):
    """Create ``dirpath`` and any missing parent directories, best-effort.

    An already-existing directory is not an error. Any other filesystem
    failure (permission denied, a regular file in the way, ...) is logged
    as a warning rather than raised, so callers keep the original
    "never fails on OS errors" contract while the problem stays visible.

    Args:
        dirpath: Path of the directory to create.
    """
    try:
        os.makedirs(dirpath, exist_ok=True)
    except OSError as err:
        # Best-effort by design: report the failure instead of hiding it.
        logging.warning("Could not create directory %s: %s", dirpath, err)
class CIFAR10_truncated(data.Dataset):
    """CIFAR-10 restricted to an optional subset of sample indices.

    Images and labels are pulled out of a torchvision ``CIFAR10`` object
    once, at construction time, and kept in ``self.data`` / ``self.target``.

    Args:
        root: Dataset directory handed to torchvision's ``CIFAR10``.
        dataidxs: Indices of the samples to keep; ``None`` keeps everything.
        train: Passed to ``CIFAR10`` to choose the train or test split.
        transform: Optional callable applied to each image in ``__getitem__``.
        target_transform: Optional callable applied to each label.
        download: Forwarded to torchvision's ``CIFAR10``.
        is_labeled: Stored on the instance; not read by this class itself.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False,
                 is_labeled=False):
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download
        self.is_labeled = is_labeled
        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """Load CIFAR-10 via torchvision and return the selected ``(data, target)``."""
        cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download)

        if torchvision.__version__ == '0.2.1':
            # torchvision 0.2.1 used split-specific attribute names.
            if self.train:
                images, labels = cifar_dataobj.train_data, np.array(cifar_dataobj.train_labels)
            else:
                images, labels = cifar_dataobj.test_data, np.array(cifar_dataobj.test_labels)
        else:
            images = cifar_dataobj.data
            labels = np.array(cifar_dataobj.targets)

        if self.dataidxs is not None:
            images = images[self.dataidxs]
            labels = labels[self.dataidxs]

        return images, labels

    def truncate_channel(self, index):
        """Zero channels 1 and 2 (last axis) of the selected samples, in place.

        Args:
            index: Integer positions into ``self.data`` (array or sequence).
                An empty selection leaves the data untouched.
        """
        positions = np.asarray(index)
        if positions.size == 0:
            # np.asarray([]) is float-typed and cannot be used as an index.
            return
        # One vectorized write instead of a Python loop over every position.
        self.data[positions, :, :, 1:3] = 0

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``.

        The image is handed to ``transform`` exactly as stored in
        ``self.data``; no PIL conversion is performed here.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.target[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples kept in this dataset."""
        return len(self.data)
class CIFAR100_truncated(data.Dataset):
    """CIFAR-100 restricted to an optional subset of sample indices.

    Images and labels are pulled out of a torchvision ``CIFAR100`` object
    once, at construction time, and kept in ``self.data`` / ``self.target``.

    Args:
        root: Dataset directory handed to torchvision's ``CIFAR100``.
        dataidxs: Indices of the samples to keep; ``None`` keeps everything.
        train: Passed to ``CIFAR100`` to choose the train or test split.
        transform: Optional callable applied to each PIL image.
        target_transform: Optional callable applied to each label.
        download: Forwarded to torchvision's ``CIFAR100``.
        is_labeled: Stored on the instance; not read by this class itself.
            Accepted so the constructor matches the other truncated
            dataset classes in this file.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False,
                 is_labeled=False):
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download
        self.is_labeled = is_labeled
        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """Load CIFAR-100 via torchvision and return the selected ``(data, target)``."""
        cifar_dataobj = CIFAR100(self.root, self.train, self.transform, self.target_transform, self.download)

        if torchvision.__version__ == '0.2.1':
            # torchvision 0.2.1 used split-specific attribute names.
            if self.train:
                images, labels = cifar_dataobj.train_data, np.array(cifar_dataobj.train_labels)
            else:
                images, labels = cifar_dataobj.test_data, np.array(cifar_dataobj.test_labels)
        else:
            images = cifar_dataobj.data
            labels = np.array(cifar_dataobj.targets)

        if self.dataidxs is not None:
            images = images[self.dataidxs]
            labels = labels[self.dataidxs]

        return images, labels

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``.

        The stored array is converted to a PIL image before ``transform``
        is applied.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.target[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples kept in this dataset."""
        return len(self.data)
class SVHN_truncated(data.Dataset):
    """SVHN restricted to an optional subset of sample indices.

    Images and labels are pulled out of a torchvision ``SVHN`` object once,
    at construction time, and kept in ``self.data`` / ``self.target``.

    Args:
        root: Dataset directory handed to torchvision's ``SVHN``.
        dataidxs: Indices of the samples to keep; ``None`` keeps everything.
        split: Split name forwarded to torchvision's ``SVHN``.
        transform: Optional callable applied to each image in ``__getitem__``.
        target_transform: Optional callable applied to each label.
        download: Forwarded to torchvision's ``SVHN``.
        is_labeled: Stored on the instance; not read by this class itself.
    """

    def __init__(self, root, dataidxs=None, split='train', transform=None, target_transform=None, download=False,
                 is_labeled=False):
        self.root = root
        self.dataidxs = dataidxs
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.download = download
        self.is_labeled = is_labeled
        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """Load SVHN via torchvision and return the selected ``(data, target)``."""
        svhn_dataobj = SVHN(self.root, self.split, self.transform, self.target_transform, self.download)

        # torchvision's SVHN exposes ``data`` / ``labels`` (including in
        # 0.2.1). The former version-specific branch read
        # ``train_data`` / ``train_labels``, which SVHN does not define, and
        # branched on the always-truthy ``split`` string, so it could only
        # fail. Read the real attributes unconditionally.
        images = svhn_dataobj.data
        labels = np.array(svhn_dataobj.labels)

        if self.dataidxs is not None:
            images = images[self.dataidxs]
            labels = labels[self.dataidxs]

        return images, labels

    def truncate_channel(self, index):
        """Zero positions 1 and 2 along the last axis of the selected samples.

        NOTE(review): this indexing treats the last axis as the colour
        channel. torchvision's SVHN appears to store images channel-first
        (N, C, H, W), in which case this zeroes image columns rather than
        channels - confirm the intended layout before relying on it.

        Args:
            index: Integer positions into ``self.data`` (array or sequence).
                An empty selection leaves the data untouched.
        """
        positions = np.asarray(index)
        if positions.size == 0:
            # np.asarray([]) is float-typed and cannot be used as an index.
            return
        # One vectorized write instead of a Python loop over every position.
        self.data[positions, :, :, 1:3] = 0

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``.

        The image is handed to ``transform`` exactly as stored in
        ``self.data``; no PIL conversion is performed here.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.target[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples kept in this dataset."""
        return len(self.data)
class ImageFolder_custom(DatasetFolder):
    """Image-folder dataset restricted to an optional subset of samples.

    The directory scan is delegated to a temporary torchvision
    ``ImageFolder``; only its ``loader`` and ``samples`` are kept.

    NOTE(review): ``DatasetFolder.__init__`` is not invoked, so attributes
    the base class would normally set are absent on instances - confirm no
    caller relies on them.

    Args:
        root: Root directory handed to ``ImageFolder``.
        dataidxs: Index selection applied to the sample array; ``None``
            keeps every sample.
        train: Stored on the instance; not read by this class itself.
        transform: Optional callable applied to each loaded sample.
        target_transform: Optional callable applied to each label.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None):
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        imagefolder_obj = ImageFolder(self.root, self.transform, self.target_transform)
        self.loader = imagefolder_obj.loader

        # Convert once. The resulting array holds strings, which is why
        # __getitem__ casts the label back with int().
        all_samples = np.array(imagefolder_obj.samples)
        if self.dataidxs is not None:
            self.samples = all_samples[self.dataidxs]
        else:
            self.samples = all_samples

    def __getitem__(self, index):
        """Load and return the ``(sample, target)`` pair at ``index``."""
        path = self.samples[index][0]
        target = int(self.samples[index][1])
        sample = self.loader(path)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        """Return the number of samples actually held.

        Measuring ``self.samples`` rather than ``self.dataidxs`` keeps the
        length right when ``dataidxs`` is a boolean mask or a slice, where
        ``len(dataidxs)`` is either wrong or raises.
        """
        return len(self.samples)
0