# Source code for steganogan.models

# -*- coding: utf-8 -*-
import gc
import inspect
import json
import os
from collections import Counter

import imageio
import torch
from imageio import imread, imwrite
from torch.nn.functional import binary_cross_entropy_with_logits, mse_loss
from torch.optim import Adam
from tqdm import tqdm

from steganogan.utils import bits_to_bytearray, bytearray_to_text, ssim, text_to_bits

# Directory named 'train' located next to this module.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_PATH = os.path.join(_MODULE_DIR, 'train')

# Metrics collected on the validation set during each epoch.
_VALIDATION_FIELDS = [
    'val.encoder_mse',
    'val.decoder_loss',
    'val.decoder_acc',
    'val.cover_score',
    'val.generated_score',
    'val.ssim',
    'val.psnr',
    'val.bpp',
]

# Metrics collected on the training set during each epoch.
_TRAINING_FIELDS = [
    'train.encoder_mse',
    'train.decoder_loss',
    'train.decoder_acc',
    'train.cover_score',
    'train.generated_score',
]

# Every metric name tracked per epoch: validation fields first, then training fields.
METRIC_FIELDS = _VALIDATION_FIELDS + _TRAINING_FIELDS


class SteganoGAN(object):
    """Steganography model built from an encoder, a decoder and a critic network.

    The encoder hides a binary payload inside a cover image, the decoder
    recovers the payload from the generated image, and the critic scores
    images during adversarial training.
    """

    def _get_instance(self, class_or_instance, kwargs):
        """Return an instance of the given class, or the instance itself.

        Args:
            class_or_instance: Either a ready-made object (returned untouched)
                or a class whose ``__init__`` parameters are looked up by name
                in ``kwargs``.
            kwargs (dict): Candidate constructor arguments. Extra keys are ignored.

        Returns:
            The given instance, or a new instance of the given class.

        Raises:
            KeyError: If a constructor parameter without a default value is
                missing from ``kwargs``.
        """
        if not inspect.isclass(class_or_instance):
            return class_or_instance

        spec = inspect.getfullargspec(class_or_instance.__init__)
        arg_names = spec.args[1:]  # drop the bound instance parameter

        # Parameters with a default may be omitted from kwargs; the default is used.
        defaults = spec.defaults or ()
        optional = set(spec.args[len(spec.args) - len(defaults):]) if defaults else set()

        init_args = {
            name: kwargs[name]
            for name in arg_names
            if name in kwargs or name not in optional
        }
        return class_or_instance(**init_args)

    def set_device(self, cuda=True):
        """Set the torch device depending on whether cuda is available or not.

        Args:
            cuda (bool): Request the CUDA device. Falls back to the CPU when
                CUDA is not available.
        """
        if cuda and torch.cuda.is_available():
            self.cuda = True
            self.device = torch.device('cuda')
        else:
            self.cuda = False
            self.device = torch.device('cpu')

        if self.verbose:
            if not cuda:
                print('Using CPU device')
            elif not self.cuda:
                print('CUDA is not available. Defaulting to CPU device')
            else:
                print('Using CUDA device')

        self.encoder.to(self.device)
        self.decoder.to(self.device)
        self.critic.to(self.device)

    def __init__(self, data_depth, encoder, decoder, critic,
                 cuda=False, verbose=False, log_dir=None, **kwargs):
        """Build the three networks and prepare the logging directories.

        Args:
            data_depth (int): Number of payload channels hidden per pixel.
            encoder: Encoder instance, or class to instantiate from ``kwargs``.
            decoder: Decoder instance, or class to instantiate from ``kwargs``.
            critic: Critic instance, or class to instantiate from ``kwargs``.
            cuda (bool): Whether to try to use the CUDA device.
            verbose (bool): Whether to print progress information.
            log_dir (str): Optional directory for metrics, checkpoints and samples.
            **kwargs: Extra constructor arguments for the three networks.
        """
        self.verbose = verbose

        self.data_depth = data_depth
        kwargs['data_depth'] = data_depth
        self.encoder = self._get_instance(encoder, kwargs)
        self.decoder = self._get_instance(decoder, kwargs)
        self.critic = self._get_instance(critic, kwargs)
        self.set_device(cuda)

        # Optimizers are created lazily on the first call to fit().
        self.critic_optimizer = None
        self.decoder_optimizer = None

        # Misc
        self.fit_metrics = None
        self.history = list()

        self.log_dir = log_dir
        self.samples_path = None
        if log_dir:
            os.makedirs(self.log_dir, exist_ok=True)
            self.samples_path = os.path.join(self.log_dir, 'samples')
            os.makedirs(self.samples_path, exist_ok=True)

    def _random_data(self, cover):
        """Generate random data ready to be hidden inside the cover image.

        Args:
            cover (torch.Tensor): Batch of cover images with four dimensions;
                the first and the last two sizes are reused for the payload.

        Returns:
            torch.Tensor: Tensor of random zeros and ones with shape
                ``(N, data_depth, H, W)`` on the model device.
        """
        N, _, H, W = cover.size()
        return torch.zeros((N, self.data_depth, H, W), device=self.device).random_(0, 2)

    def _encode_decode(self, cover, quantize=False):
        """Encode random data and then decode it.

        Args:
            cover (torch.Tensor): Batch of images to use as cover.
            quantize (bool): Whether to round the generated image to 256
                integer levels before decoding it.

        Returns:
            generated (torch.Tensor): Images produced by the encoder.
            payload (torch.Tensor): Random data that has been encoded.
            decoded (torch.Tensor): Decoder output for the generated images.
        """
        payload = self._random_data(cover)
        generated = self.encoder(cover, payload)
        if quantize:
            # Map [-1, 1] to integer levels [0, 255] and back again.
            generated = (255.0 * (generated + 1.0) / 2.0).long()
            generated = 2.0 * generated.float() / 255.0 - 1.0

        decoded = self.decoder(generated)

        return generated, payload, decoded

    def _critic(self, image):
        """Evaluate the image using the critic and return the mean score."""
        return torch.mean(self.critic(image))

    def _get_optimizers(self):
        """Create the Adam optimizers for the critic and for the encoder/decoder pair.

        Returns:
            tuple: ``(critic_optimizer, decoder_optimizer)``; the second one
                updates both the decoder and the encoder parameters.
        """
        _dec_list = list(self.decoder.parameters()) + list(self.encoder.parameters())
        critic_optimizer = Adam(self.critic.parameters(), lr=1e-4)
        decoder_optimizer = Adam(_dec_list, lr=1e-4)

        return critic_optimizer, decoder_optimizer

    def _fit_critic(self, train, metrics):
        """Run one pass over the train images updating only the critic.

        Args:
            train (iterable): Yields ``(cover, _)`` batches.
            metrics (dict): Lists of per-batch values, appended to in place.
        """
        for cover, _ in tqdm(train, disable=not self.verbose):
            gc.collect()
            cover = cover.to(self.device)
            payload = self._random_data(cover)
            generated = self.encoder(cover, payload)
            cover_score = self._critic(cover)
            generated_score = self._critic(generated)

            self.critic_optimizer.zero_grad()
            (cover_score - generated_score).backward(retain_graph=False)
            self.critic_optimizer.step()

            # Clip the critic weights to [-0.1, 0.1] after every step.
            for p in self.critic.parameters():
                p.data.clamp_(-0.1, 0.1)

            metrics['train.cover_score'].append(cover_score.item())
            metrics['train.generated_score'].append(generated_score.item())

    def _fit_coders(self, train, metrics):
        """Fit the encoder and the decoder on the train images.

        Args:
            train (iterable): Yields ``(cover, _)`` batches.
            metrics (dict): Lists of per-batch values, appended to in place.
        """
        for cover, _ in tqdm(train, disable=not self.verbose):
            gc.collect()
            cover = cover.to(self.device)
            generated, payload, decoded = self._encode_decode(cover)
            encoder_mse, decoder_loss, decoder_acc = self._coding_scores(
                cover, generated, payload, decoded)
            generated_score = self._critic(generated)

            self.decoder_optimizer.zero_grad()
            (100.0 * encoder_mse + decoder_loss + generated_score).backward()
            self.decoder_optimizer.step()

            metrics['train.encoder_mse'].append(encoder_mse.item())
            metrics['train.decoder_loss'].append(decoder_loss.item())
            metrics['train.decoder_acc'].append(decoder_acc.item())

    def _coding_scores(self, cover, generated, payload, decoded):
        """Compute the image distortion and the payload recovery scores.

        Args:
            cover (torch.Tensor): Original images.
            generated (torch.Tensor): Images produced by the encoder.
            payload (torch.Tensor): Bits that were hidden (zeros and ones).
            decoded (torch.Tensor): Decoder logits for each payload bit.

        Returns:
            tuple: ``(encoder_mse, decoder_loss, decoder_acc)``: the mean squared
                error between generated and cover, the binary cross entropy of
                the logits against the payload, and the fraction of bits whose
                sign matches the payload.
        """
        encoder_mse = mse_loss(generated, cover)
        decoder_loss = binary_cross_entropy_with_logits(decoded, payload)
        decoder_acc = (decoded >= 0.0).eq(payload >= 0.5).sum().float() / payload.numel()

        return encoder_mse, decoder_loss, decoder_acc

    def _validate(self, validate, metrics):
        """Evaluate the model on the validation images.

        Args:
            validate (iterable): Yields ``(cover, _)`` batches.
            metrics (dict): Lists of per-batch values, appended to in place.
        """
        # Only scalar values are kept, so no autograd graph is needed.
        with torch.no_grad():
            for cover, _ in tqdm(validate, disable=not self.verbose):
                gc.collect()
                cover = cover.to(self.device)
                generated, payload, decoded = self._encode_decode(cover, quantize=True)
                encoder_mse, decoder_loss, decoder_acc = self._coding_scores(
                    cover, generated, payload, decoded)
                generated_score = self._critic(generated)
                cover_score = self._critic(cover)

                metrics['val.encoder_mse'].append(encoder_mse.item())
                metrics['val.decoder_loss'].append(decoder_loss.item())
                metrics['val.decoder_acc'].append(decoder_acc.item())
                metrics['val.cover_score'].append(cover_score.item())
                metrics['val.generated_score'].append(generated_score.item())
                metrics['val.ssim'].append(ssim(cover, generated).item())
                # Peak value 4 = 2 ** 2, presumably because pixels span [-1, 1].
                metrics['val.psnr'].append(10 * torch.log10(4 / encoder_mse).item())
                metrics['val.bpp'].append(self.data_depth * (2 * decoder_acc.item() - 1))

    def _generate_samples(self, samples_path, cover, epoch):
        """Write the cover images and their generated counterparts as PNG files.

        Args:
            samples_path (str): Directory where the images are written.
            cover (torch.Tensor): Batch of cover images.
            epoch (int): Epoch number embedded in the generated file names.
        """
        cover = cover.to(self.device)
        with torch.no_grad():
            generated, payload, decoded = self._encode_decode(cover)

        samples = generated.size(0)
        for sample in range(samples):
            cover_path = os.path.join(samples_path, '{}.cover.png'.format(sample))
            sample_name = '{}.generated-{:2d}.png'.format(sample, epoch)
            sample_path = os.path.join(samples_path, sample_name)

            image = (cover[sample].permute(1, 2, 0).detach().cpu().numpy() + 1.0) / 2.0
            imageio.imwrite(cover_path, (255.0 * image).astype('uint8'))

            sampled = generated[sample].clamp(-1.0, 1.0).permute(1, 2, 0)
            sampled = sampled.detach().cpu().numpy() + 1.0

            image = sampled / 2.0
            imageio.imwrite(sample_path, (255.0 * image).astype('uint8'))

    def fit(self, train, validate, epochs=5):
        """Train the model with the given train and validation loaders.

        Calling fit again continues training: ``self.epochs`` keeps counting
        and is the epoch number used in metrics, checkpoints and samples.

        Args:
            train (iterable): Yields ``(cover, _)`` training batches.
            validate (iterable): Yields ``(cover, _)`` validation batches.
            epochs (int): Number of additional epochs to train.
        """
        if self.critic_optimizer is None:
            self.critic_optimizer, self.decoder_optimizer = self._get_optimizers()
            self.epochs = 0

        if self.log_dir:
            sample_cover = next(iter(validate))[0]

        # Start training
        total = self.epochs + epochs
        for _ in range(epochs):
            # Count how many epochs we have trained for this steganogan
            self.epochs += 1

            metrics = {field: list() for field in METRIC_FIELDS}

            if self.verbose:
                print('Epoch {}/{}'.format(self.epochs, total))

            self._fit_critic(train, metrics)
            self._fit_coders(train, metrics)
            self._validate(validate, metrics)

            self.fit_metrics = {k: sum(v) / len(v) for k, v in metrics.items()}
            # Use the cumulative counter so repeated fit() calls do not restart at 1.
            self.fit_metrics['epoch'] = self.epochs

            if self.log_dir:
                self.history.append(self.fit_metrics)

                metrics_path = os.path.join(self.log_dir, 'metrics.log')
                with open(metrics_path, 'w') as metrics_file:
                    json.dump(self.history, metrics_file, indent=4)

                # NOTE: '{:03f}' is a minimum width of 3 with the default six decimals.
                save_name = '{}.bpp-{:03f}.p'.format(
                    self.epochs, self.fit_metrics['val.bpp'])

                self.save(os.path.join(self.log_dir, save_name))
                self._generate_samples(self.samples_path, sample_cover, self.epochs)

            # Empty cuda cache (this may help for memory leaks)
            if self.cuda:
                torch.cuda.empty_cache()

            gc.collect()

    def _make_payload(self, width, height, depth, text):
        """Encode a piece of text into a payload tensor.

        The text is converted to a bit vector followed by 32 zero bits (the
        separator that decode() splits on), and the result is repeated until
        it fills ``width * height * depth`` positions.

        Args:
            width (int): Size of the last payload dimension.
            height (int): Size of the third payload dimension.
            depth (int): Number of payload channels.
            text (str): Message to encode.

        Returns:
            torch.FloatTensor: Payload of shape ``(1, depth, height, width)``.
        """
        message = text_to_bits(text) + [0] * 32
        capacity = width * height * depth

        # Tile the message just enough times to cover the capacity, then trim.
        repeats = -(-capacity // len(message))  # ceiling division
        payload = (message * repeats)[:capacity]

        return torch.FloatTensor(payload).view(1, depth, height, width)

    def encode(self, cover, output, text):
        """Encode an image.

        Args:
            cover (str): Path to the image to be used as cover.
            output (str): Path where the generated image will be saved.
            text (str): Message to hide inside the image.
        """
        # Scale the pixels from [0, 255] to [-1, 1].
        cover = imread(cover, pilmode='RGB') / 127.5 - 1.0
        cover = torch.FloatTensor(cover).permute(2, 1, 0).unsqueeze(0)

        cover_size = cover.size()
        payload = self._make_payload(cover_size[3], cover_size[2], self.data_depth, text)

        cover = cover.to(self.device)
        payload = payload.to(self.device)
        with torch.no_grad():
            generated = self.encoder(cover, payload)[0].clamp(-1.0, 1.0)

        generated = (generated.permute(2, 1, 0).detach().cpu().numpy() + 1.0) * 127.5
        imwrite(output, generated.astype('uint8'))

        if self.verbose:
            print('Encoding completed.')

    def decode(self, image):
        """Recover the message hidden inside an image.

        Args:
            image (str): Path to the image to decode.

        Returns:
            str: The candidate message that appears most often in the decoded bits.

        Raises:
            ValueError: If the image path does not exist or no message is found.
        """
        if not os.path.exists(image):
            raise ValueError('Unable to read %s.' % image)

        # extract a bit vector
        # NOTE(review): encode() scales pixels to [-1, 1] (x / 127.5 - 1.0) while
        # this scales them to [0, 1]. Left unchanged because existing trained
        # models may depend on it -- confirm before aligning the two.
        pixels = imread(image, pilmode='RGB') / 255.0
        pixels = torch.FloatTensor(pixels).permute(2, 1, 0).unsqueeze(0)
        pixels = pixels.to(self.device)

        with torch.no_grad():
            decoded = self.decoder(pixels).view(-1) > 0

        # split and decode messages
        candidates = Counter()
        bits = decoded.cpu().numpy().tolist()
        for candidate in bits_to_bytearray(bits).split(b'\x00\x00\x00\x00'):
            candidate = bytearray_to_text(bytearray(candidate))
            if candidate:
                candidates[candidate] += 1

        # choose most common message
        if not candidates:
            raise ValueError('Failed to find message.')

        return candidates.most_common(1)[0][0]

    def save(self, path):
        """Save the whole SteganoGAN instance in the given path using torch.save.

        Args:
            path (str): Destination file.
        """
        torch.save(self, path)

    @classmethod
    def load(cls, architecture=None, path=None, cuda=True, verbose=False):
        """Load a SteganoGAN instance from a bundled architecture or a custom path.

        Exactly one of ``architecture`` and ``path`` must be given.

        Args:
            architecture (str): Name of a pretrained model to be loaded from
                the ``pretrained`` directory next to this module.
            path (str): Path to a custom pretrained model. ``architecture``
                must be None.
            cuda (bool): Force loaded model to use cuda (if available).
            verbose (bool): Force loaded model to use or not verbose.

        Returns:
            SteganoGAN: The loaded model, moved to the selected device.

        Raises:
            ValueError: If neither or both of ``architecture`` and ``path``
                are provided.
        """
        if bool(architecture) == bool(path):
            raise ValueError(
                'Please provide either an architecture or a path to pretrained model.')

        if architecture:
            model_name = '{}.steg'.format(architecture)
            pretrained_path = os.path.join(os.path.dirname(__file__), 'pretrained')
            path = os.path.join(pretrained_path, model_name)

        # SECURITY: save() pickles the whole object, so loading runs the pickle
        # machinery on the file. Only load model files from trusted sources.
        try:
            # Recent torch versions refuse full-object pickles unless weights_only=False.
            steganogan = torch.load(path, map_location='cpu', weights_only=False)
        except TypeError:
            # Older torch versions do not know the weights_only argument.
            steganogan = torch.load(path, map_location='cpu')

        steganogan.verbose = verbose

        steganogan.encoder.upgrade_legacy()
        steganogan.decoder.upgrade_legacy()
        steganogan.critic.upgrade_legacy()

        steganogan.set_device(cuda)
        return steganogan