fibber.metrics.classifier.transformer_classifier module
This metric outputs the prediction of a transformer-based classifier on the paraphrased text.
class fibber.metrics.classifier.transformer_classifier.TransformerClassifier(dataset_name, trainset, testset, transformer_clf_gpu_id=-1, transformer_clf_steps=20000, transformer_clf_bs=32, transformer_clf_lr=2e-05, transformer_clf_optimizer='adamw', transformer_clf_weight_decay=0.001, transformer_clf_period_summary=100, transformer_clf_period_val=500, transformer_clf_period_save=20000, transformer_clf_val_steps=10, transformer_clf_model_init='bert-base-cased', **kwargs)
Bases: fibber.metrics.classifier.classifier_base.ClassifierBase
BERT classifier prediction on paraphrase_list.
This metric is special: it does not compare the original and paraphrased sentences. Instead, it outputs the classifier prediction on paraphrase_list, so the mean and standard deviation should not be computed on this metric.
- Parameters
dataset_name (str) – the name of the dataset.
trainset (dict) – a fibber dataset.
testset (dict) – a fibber dataset.
transformer_clf_gpu_id (int) – the GPU id for the BERT model. Set to -1 to use the CPU.
transformer_clf_steps (int) – steps to train a classifier.
transformer_clf_bs (int) – the batch size.
transformer_clf_lr (float) – the learning rate.
transformer_clf_optimizer (str) – the optimizer name.
transformer_clf_weight_decay (float) – the weight decay in the optimizer.
transformer_clf_period_summary (int) – the period in steps to write training summary.
transformer_clf_period_val (int) – the period in steps to run validation and write validation summary.
transformer_clf_period_save (int) – the period in steps to save current model.
transformer_clf_val_steps (int) – number of batches in each validation.
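Because the metric value is a predicted class label rather than a similarity score, averaging it is not meaningful. A minimal sketch of one alternative summary, assuming the per-paraphrase outputs are integer class ids and the label of the original sentence is known. The helper name `summarize_predictions` and the sample data are hypothetical and not part of fibber:

```python
from collections import Counter


def summarize_predictions(predictions, true_label):
    """Summarize categorical predictions without averaging the label ids.

    predictions: list of predicted class ids, one per paraphrase (assumed format).
    true_label: the class id of the original sentence.
    """
    counts = Counter(predictions)
    # Fraction of paraphrases that the classifier still labels correctly.
    accuracy = counts[true_label] / len(predictions) if predictions else 0.0
    return {"label_counts": dict(counts), "accuracy": accuracy}


# Hypothetical predictions for four paraphrases of a sentence with label 1.
summary = summarize_predictions([1, 0, 1, 1], true_label=1)
print(summary)
```

Counting labels keeps the categorical nature of the output intact, whereas the mean of class ids would depend on an arbitrary label numbering.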
fibber.metrics.classifier.transformer_classifier.get_optimizer(optimizer_name, lr, decay, train_step, params, warmup=1000)
Create an optimizer and a learning rate schedule for the parameters.
- Parameters
optimizer_name (str) – choose from ["adam", "sgd", "adamw"].
lr (float) – learning rate.
decay (float) – weight decay.
train_step (int) – number of training steps.
params (list) – a list of parameters in the model.
warmup (int) – number of warm-up steps.
- Returns
A torch optimizer and a scheduler.
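The documentation does not state the shape of the learning rate schedule. As an illustration only, the sketch below assumes a linear warm-up followed by a linear decay to zero, a common choice when fine-tuning transformers. `lr_at_step` is a hypothetical helper, not fibber's implementation:

```python
def lr_at_step(step, base_lr, warmup, train_step):
    """Learning rate at a given step under linear warm-up then linear decay.

    The schedule shape is an assumption made for illustration.
    """
    if step < warmup:
        # Ramp up linearly from 0 to base_lr over the warm-up steps.
        return base_lr * step / max(1, warmup)
    # Decay linearly from base_lr to 0 over the remaining steps.
    remaining = max(0, train_step - step)
    return base_lr * remaining / max(1, train_step - warmup)


# Values taken from the documented defaults: lr=2e-05, warmup=1000, 20000 steps.
for s in (0, 500, 1000, 10500, 20000):
    print(s, lr_at_step(s, base_lr=2e-05, warmup=1000, train_step=20000))
```

Under these assumptions the rate peaks at `base_lr` exactly when warm-up ends and reaches zero at `train_step`.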
fibber.metrics.classifier.transformer_classifier.load_or_train_transformer_clf(model_init, dataset_name, trainset, testset, transformer_clf_steps, transformer_clf_bs, transformer_clf_lr, transformer_clf_optimizer, transformer_clf_weight_decay, transformer_clf_period_summary, transformer_clf_period_val, transformer_clf_period_save, transformer_clf_val_steps, device)
Train a transformer-based classification model on a dataset.
The trained model will be stored at <fibber_root_dir>/transformer_clf/<dataset_name>/. If there is a saved model, load and return it. Otherwise, train the model using the given data.
- Parameters
model_init (str) – pretrained model name, e.g. ["bert-base-cased", "bert-base-uncased", "bert-large-cased", "bert-large-uncased", "roberta-large"].
dataset_name (str) – the name of the dataset. This is also the directory where the trained model is saved.
trainset (dict) – a fibber dataset.
testset (dict) – a fibber dataset.
transformer_clf_steps (int) – steps to train a classifier.
transformer_clf_bs (int) – the batch size.
transformer_clf_lr (float) – the learning rate.
transformer_clf_optimizer (str) – the optimizer name.
transformer_clf_weight_decay (float) – the weight decay.
transformer_clf_period_summary (int) – the period in steps to write training summary.
transformer_clf_period_val (int) – the period in steps to run validation and write validation summary.
transformer_clf_period_save (int) – the period in steps to save current model.
transformer_clf_val_steps (int) – number of batches in each validation.
device (torch.device) – the device to run the model on.
- Returns
A torch transformer model.
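The load-or-train behavior of this function follows a common caching pattern. A minimal sketch, using a JSON file as a stand-in for a saved model. `load_or_train`, `train_fn`, and the file name `model.json` are hypothetical and do not reflect how fibber serializes models:

```python
import json
import tempfile
from pathlib import Path


def load_or_train(model_dir, train_fn):
    """Return a saved model from model_dir if present; otherwise train and save one."""
    model_path = Path(model_dir) / "model.json"  # hypothetical file name
    if model_path.exists():
        # A saved model exists: load it and skip training.
        return json.loads(model_path.read_text())
    model = train_fn()
    model_path.parent.mkdir(parents=True, exist_ok=True)
    model_path.write_text(json.dumps(model))
    return model


calls = []


def fake_train():
    """Stand-in for training; records each call so reuse can be observed."""
    calls.append(1)
    return {"weights": [0.1, 0.2]}


with tempfile.TemporaryDirectory() as root:
    model_dir = Path(root) / "transformer_clf" / "demo_dataset"
    first = load_or_train(model_dir, fake_train)   # trains and saves
    second = load_or_train(model_dir, fake_train)  # loads the saved copy
```

Keying the directory on the dataset name means that each dataset trains at most once, and later calls reuse the stored result.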
fibber.metrics.classifier.transformer_classifier.run_evaluate(model, dataloader_iter, eval_steps, summary, global_step, device, model_init)
Evaluate a model and add the error rate and validation loss to Tensorboard.
- Parameters
model (transformers.BertForSequenceClassification) – a BERT classification model.
dataloader_iter (torch.utils.data.IterableDataset) – an iterator of a torch.utils.data.IterableDataset.
eval_steps (int) – number of evaluation steps.
summary (torch.utils.tensorboard.SummaryWriter) – a Tensorboard SummaryWriter object.
global_step (int) – the current training step.
device (torch.device) – the device the model is running on.
model_init (str) – a string specifying the pretrained model, used to determine the model input.
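As a rough sketch of what such an evaluation pass computes, the example below draws `eval_steps` batches from an iterator and accumulates the error rate and mean loss. It uses plain Python in place of torch tensors; `evaluate_batches`, `predict_fn`, and the batch format are assumptions, not fibber's API:

```python
import math


def evaluate_batches(batch_iter, eval_steps, predict_fn):
    """Compute error rate and mean loss over eval_steps batches.

    batch_iter yields (inputs, labels) pairs; predict_fn maps one input to a
    list of class probabilities. Both formats are assumed for illustration.
    """
    errors, total, loss_sum = 0, 0, 0.0
    for _ in range(eval_steps):
        inputs, labels = next(batch_iter)
        for x, y in zip(inputs, labels):
            probs = predict_fn(x)
            predicted = max(range(len(probs)), key=probs.__getitem__)
            errors += int(predicted != y)
            loss_sum += -math.log(probs[y])  # cross-entropy for one example
            total += 1
    return errors / total, loss_sum / total


def toy_predict(x):
    """Toy classifier: favors class 1 for positive inputs, class 0 otherwise."""
    return [0.2, 0.8] if x > 0 else [0.8, 0.2]


batches = iter([([1.0, -1.0], [1, 0]), ([2.0, -3.0], [0, 0])])
error_rate, mean_loss = evaluate_batches(batches, eval_steps=2, predict_fn=toy_predict)
```

In the real function these two scalars would be the values written to the SummaryWriter at `global_step`.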