文章目录
-
- 相关链接
- Dataset
-
- VisionDataset
- DatsetFolder
- ImageFolder
- torchvision.transforms
- Pytorch Lightning
-
- LightningDataModule
对于图像数据集来说,首先是在Dataset类对数据集进行定义,一般来说不定义transform,则数据为PIL Image,PIL格式到tensor的转换也是transforms变换的一种,所以定Dataset+transforms+Dataloader,最后在训练部分to(device)来得到模型的输入。
相关链接
torchvision.datasets的三个基础类
torchvision.datasets
torch.utils.data.Dataset
Pillow(PIL Fork) Image模块
Dataset
Dataset是数据集在pytorch中的化身,需要重写__ getitem__ 和 __ len__。__ __ getitem__ 通过传入的索引加载指定路径的数据,路径常常是一个列表,如很多张图片组成的数据集,需要在初始化时定义函数得到路径列表,或者在外部定义,总之要得到一个路径List。也需要在其中定义或调用具体读取的代码,如PIL库的Image.open()来读取图片,或Image.fromarray()来创建图片,也就是需要知道数据在哪里和怎么读取。
└─Dataset
└─VisionDataset
└─DatasetFolder
└─ImageFolder
Dataset是torch.utils.data中的类,是数据集的基础类
VisionDataset
VisionDataset是torchvision.datasets.vision中的类,是torchvision类数据集的基础类,相比于原始的Dataset类,提供了transform,transforms,target_transform数据变换的接口
DatasetFolder,ImageFolder都来自torchvision.datasets.folder ,既然叫做folder,实际上已经有了完整的数据集功能,可以按照默认的目录结构读取数据。DatasetFolder还需要定义loader以读取特定类型的数据,和is_valid_file或者extensions,is_valid_file和extensions不能同时定义,但必须有一个定义,如果定义了有效后缀名,会自动通过后缀来判断文件有效性。而ImageFolder更进一步,默认使用读取图像数据的loader读取,还默认定义了图像后缀名。从Dataset到ImageFolder构成了不同层次的封装,完成度越高,灵活性越低,可以根据自己的需要选择。
除了在__ getitem__ 中通过得到的路径列表来读取数据,对于不同格式的数据也有不同的做法,如torchvision中内置cifar数据集,会直接从原始数据中以矩阵的形式读取, 因此 __ getitem__ 会从矩阵中创建Image对象。总而言之,一般来讲对于图片数据集来说,__ get __返回的都是PIL Image对象,不管是从路径列表中读取,还是整个以矩阵形式读取,如果不定义transform,最后在Dataset阶段都是PIL对象。
DatsetFolder
默认的排列结构如下,每一个文件夹表示一类,下面是这一类的样本
directory/
├── class_x
│ ├── xxx.ext
│ ├── xxy.ext
│ └── ...
│ └── xxz.ext
└── class_y
├── 123.ext
├── nsdf3.ext
└── ...
└── asd932_.ext
用文件夹来区分不同的类别。比较重要的有两类操作,find_class函数得到类别名和类别序号。make_dataset得到路径列表。
默认的findclass函数
文件夹名是类名。
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset.
See :class:`DatasetFolder` for details.
"""
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
if not classes:
raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
默认的make_dataset
得到instance列表,表示文件的路径列表。
基本上很大一部分是在定义有效性判断相关,主要部分是一个双层for循环,因为类名定义为文件夹名,所以会遍历各个类的文件夹,会将遍历到的有效文件的路径加入instance,遍历过的非空类添加到available_classe。
def make_dataset(
directory: Union[str, Path],
clas服务器托管网s_to_idx: Optional[Dict[str, int]] = None,
extensions: Optional[Union[str, Tuple[str, ...]]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
allow_empty: bool = False,
) -> List[Tuple[str, int]]:
directory = os.path.expanduser(directory)
if class_to_idx is None:
_, class_to_idx = find_classes(directory)
elif not class_to_idx:
raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:
def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, extensions) # type: ignore[arg-type]
is_valid_file = cast(Callable[[str], bool], is_valid_file)
instances = []
available_classes = set()
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
if not os.path.isdir(target_dir):
continue
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = path, class_index
instances.append(item)
if target_class not in available_classes:
available_classes.add(target_class)
empty_classes = set(clas服务器托管网s_to_idx.keys()) - available_classes
if empty_classes and not allow_empty:
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
if extensions is not None:
msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
raise FileNotFoundError(msg)
return instances
ImageFolder
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
ImageFolder如名字所示,如果数据集是这种文件夹排列,而且是图像文件,又没有需要特殊定义的部分 ,可以直接实例化一个ImageFolder,而不需要重写任何部分 ,实例化一个数据集只需要传入数据集路径和tansform变换。
train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(
os.path.join(data_dir, 'train_valid_test', folder),
transform=transform_train) for folder in ['train', 'train_valid']]
torchvision.transforms
一般会在数据集实例化时,从外部传入,通常自定义的Transforms序列包含ToTensor,可以将上一阶段的PIL Image转换为Tensor,而下一次变化要到训练时的to(device),这样数据最终输入完成,也可以在Dataset类中写入默认的transform。
通过torchvision.get_image_backend得到torchvision现在的后端默认为PILtorchvision.set_image_backend(backend)指定用来读取图片的包,可选accimage
Loader将数据读取为PIL对象,一般数据集定义不在数据集内部定义默认的transform图像变换,而是在外部定义一个transform序列,通常倒数第二个是torchvision.transforms.ToTensor()操作,会将一个PIL Image或者一个ndarray转换为tensor并缩放到[0.0, 1.0]。因此接下来会通过transforms.Normalize进行归一化。
PILToTensor会把PIL Image转化为tensor,但是不会进行缩放,
(
H
W
C
)
→
(
C
H
W
)
(Htimes Wtimes C)rightarrow (Ctimes H times W)
(HWC)→(CHW)
ToTensor会把PIL Image或者ndarray转换成tensor而且会进行缩放。
(
H
W
C
)
→
(
C
H
W
)
(Htimes Wtimes C)rightarrow (Ctimes H times W)
(HWC)→(CHW) 在规定的模式如RGBA,RGB,YCbCr或者dtype = np.uint8情况下,别的情况下不缩放。
Normalize只支持tensor,其他大部分操作也支持PIL,所以在ToTensor之后最后进行Normalize
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(img_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
Pytorch Lightning
Pytorch Lightning是Pytorch中的kersas,简称pl
LightningDataModule
Pytorch Lightning继承LightningDataModule定义数据集,pl中的Dataset和Dataloader是高度耦合的。
import lightning.pytorch as L
import torch.utils.data as data
from pytorch_lightning.demos.boring_classes import RandomDataset
class MyDataModule(L.LightningDataModule):
def prepare_data(self):
# download, IO, etc. Useful with shared filesystems
# only called on 1 GPU/TPU in distributed
...
def setup(self, stage):
# make assignments here (val/train/test split)
# called on every process in DDP
dataset = RandomDataset(1, 100)
self.train, self.val, self.test = data.random_split(
dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
)
def train_dataloader(self):
return data.DataLoader(self.train)
def val_dataloader(self):
return data.DataLoader(self.val)
def test_dataloader(self):
return data.DataLoader(self.test)
def teardown(self):
# clean up state after the trainer stops, delete files...
# called on every process in DDP
...
服务器托管,北京服务器托管,服务器租用 http://www.fwqtg.net
前言: 人生重开模拟器是前段时间非常火的一个小游戏,接下来我们将一起学习使用c语言写一个简易版的人生重开模拟器。 网页版游戏: 人生重开模拟器 (ytecn.com) 1.实现一个简化版的人生重开模拟器 (1) 游戏开始的时候,设定初始属性:颜值,体质,智力,…