import numpy as np
from fibber import log
from fibber.datasets import builtin_datasets
from fibber.datasets.dataset_utils import get_dataset, verify_dataset
from fibber.metrics import MetricBundle
from fibber.paraphrase_strategies import (
ASRSStrategy, IdentityStrategy, RandomStrategy, TextAttackStrategy)
# Module-level logger for this file, created through fibber's
# ``log.setup_custom_logger`` helper and used for error reporting below.
logger = log.setup_custom_logger(__name__)
class Fibber(object):
    """Fibber is a unified interface for paraphrase strategies.

    It loads (or verifies) a dataset, builds a ``MetricBundle`` used to score
    paraphrases, and fits the selected paraphrase strategy on the training set.
    """

    def __init__(self, arg_dict, dataset_name, strategy_name, field="text0",
                 trainset=None, testset=None, output_dir=".", bert_clf_steps=5000):
        """Initialize the dataset, the metric bundle and the paraphrase strategy.

        Args:
            arg_dict (dict): a dict of hyper parameters for the MetricBundle and
                strategy. The keys ``bert_ppl_gpu_id``, ``use_gpu_id``,
                ``transformer_clf_gpu_id`` and ``strategy_gpu_id`` are read here;
                the whole dict is also forwarded to the strategy.
            dataset_name (str): the name of the dataset. If it is a builtin
                dataset, the train and test sets are loaded with ``get_dataset``.
            strategy_name (str): the strategy name, one of ``"RandomStrategy"``,
                ``"IdentityStrategy"``, ``"TextAttackStrategy"`` or
                ``"ASRSStrategy"``.
            field (str): the key of the text field in each data record that is
                paraphrased and measured.
            trainset (dict): fibber dataset. Must be None for builtin datasets.
            testset (dict): fibber testset. Must be None for builtin datasets.
            output_dir (str): directory to cache the strategy.
            bert_clf_steps (int): forwarded to ``MetricBundle`` as
                ``transformer_clf_steps``.

        Raises:
            RuntimeError: if ``strategy_name`` is unknown, or if ``trainset`` /
                ``testset`` are given together with a builtin ``dataset_name``.
        """
        super(Fibber, self).__init__()
        self._field = field

        # Resolve the strategy class first so an unknown name fails fast, before
        # the (potentially expensive) MetricBundle is constructed.
        strategy_cls = self._lookup_strategy(strategy_name)

        trainset, testset = self._load_datasets(dataset_name, trainset, testset)

        self._metric_bundle = MetricBundle(
            field=field,
            enable_transformer_classifier=True,
            enable_bert_perplexity=True,
            enable_gpt2_perplexity=False,
            enable_glove_similarity=False,
            bert_ppl_gpu_id=arg_dict["bert_ppl_gpu_id"],
            use_gpu_id=arg_dict["use_gpu_id"],
            transformer_gpu_id=arg_dict["transformer_clf_gpu_id"],
            dataset_name=dataset_name,
            trainset=trainset, testset=testset,
            transformer_clf_steps=bert_clf_steps)

        self._strategy = strategy_cls(
            arg_dict, dataset_name, arg_dict["strategy_gpu_id"], output_dir,
            self._metric_bundle, field=field)
        self._strategy.fit(trainset)

        self._trainset = trainset
        self._testset = testset

    @staticmethod
    def _lookup_strategy(strategy_name):
        """Return the paraphrase strategy class registered under ``strategy_name``.

        Raises:
            RuntimeError: if no strategy is registered under that name.
        """
        strategies = {
            "RandomStrategy": RandomStrategy,
            "IdentityStrategy": IdentityStrategy,
            "TextAttackStrategy": TextAttackStrategy,
            "ASRSStrategy": ASRSStrategy,
        }
        strategy_cls = strategies.get(strategy_name)
        if strategy_cls is None:
            logger.error("unknown strategy name %s.", strategy_name)
            raise RuntimeError("unknown strategy name %s." % strategy_name)
        return strategy_cls

    @staticmethod
    def _load_datasets(dataset_name, trainset, testset):
        """Return ``(trainset, testset)`` for ``dataset_name``.

        Builtin datasets are loaded via ``get_dataset`` (and must not be given
        explicitly); otherwise the provided sets are checked with
        ``verify_dataset`` and returned unchanged.

        Raises:
            RuntimeError: if a builtin name is combined with explicit datasets.
        """
        if dataset_name in builtin_datasets:
            if trainset is not None or testset is not None:
                # %s, not %d: dataset_name is a string.
                logger.error("dataset name %s conflict with builtin dataset. "
                             "set trainset and testset to None.", dataset_name)
                raise RuntimeError(
                    "dataset name %s conflicts with builtin dataset; "
                    "set trainset and testset to None." % dataset_name)
            trainset, testset = get_dataset(dataset_name)
            return trainset, testset
        verify_dataset(trainset)
        verify_dataset(testset)
        return trainset, testset

    def paraphrase(self, data_record, n=20):
        """Paraphrase a given data record.

        Args:
            data_record (dict): data record to be paraphrased; the text is read
                from the field chosen at construction time.
            n (int): number of paraphrases requested from the strategy.

        Returns:
            * a str as the original text.
            * a list of str as paraphrased sentences.
            * a list of dict as corresponding metrics, one per paraphrase.
        """
        paraphrases, _ = self._strategy.paraphrase_example(data_record, n)
        # Read the original text after the strategy call, as before, so any
        # change the strategy makes to the record is reflected identically.
        original_text = data_record[self._field]
        metrics = [
            self._metric_bundle.measure_example(original_text, item, data_record)
            for item in paraphrases]
        return original_text, paraphrases, metrics

    def paraphrase_a_random_sentence(self, n=20, from_testset=True):
        """Randomly pick one data record, then paraphrase it.

        Args:
            n (int): number of paraphrases.
            from_testset (bool): if true, select data from test set, otherwise
                from training set.

        Returns:
            * a str as the original text.
            * a list of str as the paraphrased text.
            * a list of dict as corresponding metrics.

        Raises:
            ValueError: if the selected dataset has no records.
        """
        dataset = self._testset if from_testset else self._trainset
        records = dataset["data"]
        if len(records) == 0:
            raise ValueError("cannot pick a random record from an empty dataset.")
        # Index directly instead of np.random.choice, which would copy the whole
        # list of records into a numpy object array on every call.
        data_record = records[np.random.randint(len(records))]
        return self.paraphrase(data_record, n=n)

    def get_metric_bundle(self):
        """Get the metric bundle."""
        return self._metric_bundle