# Source code for steganogan.loader

# -*- coding: utf-8 -*-

import numpy as np
import torch
import torchvision
from torchvision import transforms

# Per-channel (three entries) mean and standard deviation handed to
# ``transforms.Normalize`` below, which computes ``(x - mean) / std`` for each
# channel. With 0.5 / 0.5 this maps values in [0, 1] (what ``ToTensor``
# produces for 8-bit images) onto [-1, 1].
_DEFAULT_MU = [0.5] * 3
_DEFAULT_SIGMA = [0.5] * 3

# Default per-image pipeline, used by ``DataLoader`` when the caller does not
# supply a transform.
DEFAULT_TRANSFORM = transforms.Compose(
    [
        # Augmentation: random left/right mirroring.
        transforms.RandomHorizontalFlip(),
        # Random square 360x360 crop; images smaller than the crop size are
        # padded first instead of raising.
        transforms.RandomCrop(size=360, pad_if_needed=True),
        # Convert the image to a tensor.
        transforms.ToTensor(),
        # Rescale each channel with the constants defined above.
        transforms.Normalize(mean=_DEFAULT_MU, std=_DEFAULT_SIGMA),
    ]
)


class ImageFolder(torchvision.datasets.ImageFolder):
    """A ``torchvision.datasets.ImageFolder`` whose size can be capped.

    Only the first ``limit`` samples (by index, as listed by the parent class)
    are exposed. ``len()`` reports at most ``limit``, and indexing outside the
    capped range raises ``IndexError``. This means samplers, direct indexing
    and the ``__getitem__``-based iteration protocol all see the same capped
    dataset.

    Args:
        path: Root directory passed to ``torchvision.datasets.ImageFolder``.
        transform: Transform applied to each loaded image.
        limit: Maximum number of samples to expose. Defaults to ``np.inf``
            (no cap). Non-integer values are truncated towards zero, and
            negative values behave like 0.
    """

    def __init__(self, path, transform, limit=np.inf):
        super().__init__(path, transform=transform)
        self.limit = limit

    def __len__(self):
        """Return the number of exposed samples: ``min(total, limit)``."""
        length = super().__len__()
        # ``len()`` requires a non-negative int, but ``min`` can return
        # ``self.limit`` itself, which may be a float (e.g. ``100.0``) or
        # negative.
        return max(0, int(min(length, self.limit)))

    def __getitem__(self, index):
        """Return sample ``index`` from the capped range.

        Negative indices count back from the end of the capped range, not the
        full folder.

        Raises:
            IndexError: If ``index`` falls outside ``[-len(self), len(self))``.
        """
        size = len(self)
        # Enforce the cap on direct access too. Otherwise ``dataset[i]`` and
        # plain iteration would silently reach samples beyond ``limit``.
        if not -size <= index < size:
            raise IndexError(
                f"index {index} is out of range for dataset of size {size}"
            )
        if index < 0:
            index += size
        return super().__getitem__(index)
class DataLoader(torch.utils.data.DataLoader):
    """``torch.utils.data.DataLoader`` over a size-capped ``ImageFolder``.

    Args:
        path: Root directory of the image folder.
        transform: Per-image transform. ``None`` selects ``DEFAULT_TRANSFORM``.
        limit: Maximum number of images exposed by the underlying
            ``ImageFolder``. Defaults to ``np.inf`` (no cap).
        shuffle: Forwarded to ``torch.utils.data.DataLoader``. It defaults to
            ``True`` here, so pass ``shuffle=False`` when also supplying a
            ``sampler`` or ``batch_sampler`` through ``kwargs``; torch rejects
            those combinations.
        num_workers: Forwarded to ``torch.utils.data.DataLoader``.
        batch_size: Forwarded to ``torch.utils.data.DataLoader``.
        *args: Not supported (see Raises). Kept only so the signature stays
            unchanged.
        **kwargs: Additional keyword arguments forwarded to
            ``torch.utils.data.DataLoader``.

    Raises:
        TypeError: If extra positional arguments are supplied.
    """

    def __init__(self, path, transform=None, limit=np.inf, shuffle=True,
                 num_workers=8, batch_size=4, *args, **kwargs):
        if args:
            # The parent's second positional parameter is ``batch_size``, which
            # is already passed by keyword below. Forwarding these positionally
            # would fail with an opaque "multiple values for argument" error,
            # so reject them explicitly.
            raise TypeError(
                "DataLoader does not accept extra positional arguments; "
                "pass additional torch.utils.data.DataLoader options as "
                "keyword arguments"
            )
        if transform is None:
            transform = DEFAULT_TRANSFORM
        dataset = ImageFolder(path, transform, limit)
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            **kwargs,
        )