fibber.paraphrase_strategies.asrs_strategy module¶

class fibber.paraphrase_strategies.asrs_strategy.ASRSStrategy(arg_dict, dataset_name, strategy_gpu_id, output_dir, metric_bundle, field)[source]¶

Bases: fibber.paraphrase_strategies.strategy_base.StrategyBase
Initialize the paraphrase strategy.

This function initializes self._strategy_config, self._metric_bundle, self._device, self._output_dir, and self._dataset_name. You should not overwrite this function.

self._strategy_config (dict): a dictionary that stores the strategy name and all hyperparameter values. The dict is also saved to the results.
self._metric_bundle (MetricBundle): the metrics that will be used to evaluate paraphrases. Strategies can compute metrics during paraphrasing.
self._device (torch.Device): any computation that requires a GPU accelerator should use this device.
self._output_dir (str): the directory where the strategy can save files.
self._dataset_name (str): the dataset name.
- Parameters
arg_dict (dict) – all args loaded from the command line.
dataset_name (str) – the name of the dataset.
strategy_gpu_id (int) – the id of the GPU on which to run the strategy.
output_dir (str) – a directory to save any models or temporary files.
metric_bundle (MetricBundle) – a MetricBundle object.
-
fit(trainset)[source]¶

Fit the paraphrase strategy on a training set.
- Parameters
trainset (dict) – a fibber dataset.
-
paraphrase_example(data_record, n, early_stop=False)[source]¶

Paraphrase one data record.

This function should be overwritten by subclasses. When overwriting this class, you can use self._strategy_config, self._metric_bundle, self._device, self._output_dir, and self._dataset_name.

- Parameters
data_record (dict) – a dict storing one data record of a dataset.
n (int) – the number of paraphrases to generate.
- Returns
A list containing at most n strings.
- Return type
([str,])
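For illustration only, a toy subclass of this interface might look as follows; IdentityStrategy, its behavior, and the text0 field name are hypothetical and not part of the fibber API:

```python
class IdentityStrategy:
    """Toy stand-in for a StrategyBase subclass (hypothetical)."""

    def paraphrase_example(self, data_record, n, early_stop=False):
        # A real strategy would generate new sentences here; this toy
        # version just repeats the original text, returning at most n strings.
        return [data_record["text0"]] * n
```

A real subclass would also consult self._strategy_config, self._metric_bundle, and the other attributes set up by the base class.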
-
fibber.paraphrase_strategies.asrs_strategy.all_accept_criteria(candidate_ids, stats, **kwargs)[source]¶

Always accept proposed words.

- Parameters
candidate_ids (torch.Tensor) – proposed word ids in this sampling step with size (batch_size, pos_ed-pos_st).
stats (dict) – a dict to keep track of the accept rate.
- Returns
- (np.array, None)
np.array is the same as candidate_ids. None means this criterion does not have any state.
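The accept-everything behavior can be sketched in numpy (a simplified stand-in; the real function operates on torch tensors, and the stats bookkeeping shown here is an assumption):

```python
import numpy as np

def all_accept_criteria_sketch(candidate_ids, stats):
    # Every proposed word is accepted: the accept counter grows by the
    # full batch and the returned ids equal the candidates.
    stats["accept"] = stats.get("accept", 0) + len(candidate_ids)
    return np.asarray(candidate_ids), None  # None: this criterion keeps no state
```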
-
fibber.paraphrase_strategies.asrs_strategy.clf_criteria_score(origin, paraphrases, data_record, field, clf_metric, clf_weight)[source]¶
-
fibber.paraphrase_strategies.asrs_strategy.joint_weighted_criteria(tokenizer, data_record, field, origin, batch_tensor, pos_st, pos_ed, previous_ids, candidate_ids, sim_metric, sim_threshold, sim_weight, clf_metric, clf_weight, ppl_metric, ppl_weight, burnin_weight, stats, state, device, seq_len, log_prob_previous_ids, log_prob_candidate_ids, **kwargs)[source]¶

Accept or reject candidate words using the joint weighted criteria.

- Parameters
tokenizer (transformers.BertTokenizer) – a BERT tokenizer.
data_record (dict) – the data record dict.
field (str) – the field to be rewritten.
origin (str) – the original text. Same as data_record[field].
batch_tensor (torch.Tensor) – tensor of a batch of text with size (batch_size, L).
pos_st (int) – the start position of sampling (inclusive).
pos_ed (int) – the end position of sampling (exclusive).
previous_ids (torch.Tensor) – word ids before the current sampling step with size (batch_size, pos_ed-pos_st).
candidate_ids (torch.Tensor) – proposed word ids in this sampling step with size (batch_size, pos_ed-pos_st).
sim_metric (USESimilarityMetric) – a universal sentence encoder metric object.
sim_threshold (float) – the universal sentence encoder similarity threshold.
sim_weight (float) – the weight for the USE criteria score.
clf_metric (BertClassifier) – a BertClassifier metric.
clf_weight (float) – the weight for the BERT criteria score.
ppl_metric (GPT2PerplexityMetric) – a GPT2PerplexityMetric metric.
ppl_weight (float) – the weight for the GPT2 criteria score.
burnin_weight (float) – the discount factor.
stats (dict) – a dict to keep track of the accept rate.
state (np.array) – the criteria score from the previous iteration.
seq_len (np.array) – the valid length of each sentence in the batch.
device (torch.Device) – the device that batch_tensor is on.
- Returns
- (np.array, np.array)
a 2-D int array of size (batch_size, pos_ed - pos_st). Each row i is either previous_ids[i, :] if rejected, or candidate_ids[i, :] if accepted.
a 1-D float array of criteria scores.
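The row-wise accept/reject selection in the return value can be sketched in numpy (a simplified illustration; select_rows and accept_mask are hypothetical names, not fibber API):

```python
import numpy as np

def select_rows(previous_ids, candidate_ids, accept_mask):
    # For each row i, keep candidate_ids[i, :] where accepted;
    # otherwise fall back to previous_ids[i, :].
    previous_ids = np.asarray(previous_ids)
    candidate_ids = np.asarray(candidate_ids)
    return np.where(accept_mask[:, None], candidate_ids, previous_ids)
```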
-
fibber.paraphrase_strategies.asrs_strategy.ppl_criteria_score(origin, paraphrases, ppl_metric, ppl_weight)[source]¶

Estimate the score of a sentence using GPT2 perplexity.

- Parameters
origin (str) – the original sentence.
paraphrases ([str]) – a list of paraphrases.
ppl_metric (GPT2PerplexityMetric) – a GPT2PerplexityMetric metric object.
ppl_weight (float) – the weight parameter for the criteria.
- Returns
a numpy array of size (batch_size,). All entries <= 0.
- Return type
(np.array)
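The exact scoring formula is not reproduced in this documentation; a minimal sketch consistent only with the documented contract (all entries <= 0), assuming a log-perplexity-ratio penalty, might look like:

```python
import numpy as np

def ppl_criteria_score_sketch(ratio, ppl_weight):
    # ratio: paraphrase perplexity divided by original perplexity
    # (a hypothetical input; the real function takes sentences and a metric).
    # Only increases in perplexity are penalized, so every entry is <= 0.
    ratio = np.asarray(ratio, dtype=float)
    return -ppl_weight * np.maximum(0.0, np.log(ratio))
```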
-
fibber.paraphrase_strategies.asrs_strategy.process_text(text, patterns)[source]¶

Process the text using regex patterns.

- Parameters
text (str) – the str to be post-processed.
patterns (list) – a list of substitution patterns.
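Assuming each pattern is a (regex, replacement) pair applied in order — an assumption, since the pattern format is not documented here — a minimal sketch:

```python
import re

def process_text_sketch(text, patterns):
    # Apply each (regex, replacement) substitution pattern in sequence.
    for pattern, replacement in patterns:
        text = re.sub(pattern, replacement, text)
    return text
```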
-
fibber.paraphrase_strategies.asrs_strategy.sample_word_from_logits(logits, temperature=1.0, top_k=0)[source]¶

Sample a word from a distribution.

- Parameters
logits (torch.Tensor) – tensor of logits with size (batch_size, vocab_size).
temperature (float) – the temperature of softmax. The PMF is softmax(logits/temperature).
top_k (int) – if k>0, only sample from the top k most probable words.
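Temperature and top-k sampling can be sketched in numpy (a simplified stand-in for the torch implementation):

```python
import numpy as np

def sample_word_from_logits_sketch(logits, temperature=1.0, top_k=0, rng=None):
    # Softmax with temperature; if top_k > 0, mask everything outside
    # the k most probable words before renormalizing.
    if rng is None:
        rng = np.random.default_rng(0)
    logits = np.asarray(logits, dtype=float) / temperature
    if top_k > 0:
        kth = np.sort(logits, axis=-1)[:, -top_k][:, None]
        logits = np.where(logits >= kth, logits, -np.inf)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(len(p), p=p) for p in probs])
```

With top_k=1 the sampler degenerates to argmax, which makes the behavior easy to check.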
-
fibber.paraphrase_strategies.asrs_strategy.sim_criteria_score(origin, paraphrases, sim_metric, sim_threshold, sim_weight)[source]¶

Estimate the score of a sentence using USE.

- Parameters
origin (str) – the original sentence.
paraphrases ([str]) – a list of paraphrases.
sim_metric (MetricBase) – a similarity metric object.
sim_threshold (float) – the universal sentence encoder similarity threshold.
sim_weight (float) – the weight parameter for the criteria.
- Returns
a numpy array of size (batch_size,). All entries <= 0.
- Return type
(np.array)
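A minimal sketch of a thresholded similarity penalty consistent with the documented contract (entries at or above the threshold score 0, and all entries are <= 0); the actual ASRS formula may differ:

```python
import numpy as np

def sim_criteria_score_sketch(sim, sim_threshold, sim_weight):
    # sim: precomputed similarity values (a hypothetical input; the real
    # function takes sentences and a metric object).
    # Penalize paraphrases whose similarity falls below the threshold.
    sim = np.asarray(sim, dtype=float)
    return -sim_weight * np.maximum(0.0, sim_threshold - sim)
```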