fibber.metrics.classifier.classifier_base module

class fibber.metrics.classifier.classifier_base.ClassifierBase(field, bs=32, **kwargs)[source]

Bases: fibber.metrics.metric_base.MetricBase

Base class for classifiers.

All classifiers must be derived from this class.

To implement a new classifier, you should at least override the predict_log_dist_example method. This method returns the predicted logits over classes.

Some classifiers output a label instead of a distribution. In this case, you should return a one-hot vector.

Some classifiers may run more efficiently on a batch of data. In this case, you should override the predict_log_dist_batch method. If you don’t override predict_log_dist_batch, the paraphrases in paraphrase_list are processed one by one.
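
The subclassing pattern described above can be sketched as follows. This is a minimal, self-contained illustration and not fibber's actual code: SketchClassifierBase is a hypothetical stand-in for ClassifierBase, and KeywordClassifier and its num_labels attribute are invented for the example. It shows a label-only classifier returning a one-hot vector, and a one-by-one fallback for the batch method.

```python
import numpy as np


class SketchClassifierBase:
    """Hypothetical stand-in for ClassifierBase; not fibber's actual code."""

    def predict_log_dist_example(self, origin, paraphrase, data_record=None):
        # Subclasses are expected to override this method.
        raise NotImplementedError

    def predict_log_dist_batch(self, origin, paraphrase_list, data_record=None):
        # Assumed fallback: score the paraphrases one by one and stack the rows.
        return np.stack([
            self.predict_log_dist_example(origin, paraphrase, data_record)
            for paraphrase in paraphrase_list
        ])


class KeywordClassifier(SketchClassifierBase):
    """Toy label-only classifier: class 1 if the text contains 'good', else 0."""

    num_labels = 2  # hypothetical attribute, used only in this sketch

    def predict_log_dist_example(self, origin, paraphrase, data_record=None):
        label = 1 if "good" in paraphrase.lower() else 0
        # The classifier only produces a label, so return a one-hot vector.
        one_hot = np.zeros(self.num_labels)
        one_hot[label] = 1.0
        return one_hot


clf = KeywordClassifier()
scores = clf.predict_log_dist_batch(
    "The movie was good.",
    ["The film was good.", "The film was bad."],
)
```

Here scores has one row per paraphrase and one column per class. A classifier that can score a whole batch at once would override predict_log_dist_batch instead of relying on the loop.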

load_robust_tuned_model(save_path)[source]
predict_batch(origin, paraphrase_list, data_record=None)[source]

Predict class labels for one batch.

Parameters
  • origin (str) – the original text.

  • paraphrase_list (list) – a list of paraphrases.

  • data_record (dict) – the corresponding data record of original text.

Returns

predicted labels as a numpy array of size (batch_size).

Return type

(np.array)
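
One plausible way to obtain labels from a per-class score matrix is to take the argmax along the class axis. The sketch below assumes that relationship for illustration only; this page does not confirm how predict_batch is implemented, and labels_from_log_dist is a hypothetical helper name.

```python
import numpy as np


def labels_from_log_dist(log_dist):
    """Pick the highest-scoring class for each row of a (batch_size, num_labels) array."""
    return np.argmax(log_dist, axis=1)


# Two examples, three classes (made-up probabilities, converted to log space).
log_dist = np.log(np.array([[0.7, 0.2, 0.1],
                            [0.1, 0.3, 0.6]]))
labels = labels_from_log_dist(log_dist)
```

The result is a one-dimensional array with one integer label per paraphrase, matching the (batch_size) size stated above.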

predict_example(origin, paraphrase, data_record=None)[source]

Predict class label for one example.

Parameters
  • origin (str) – the original text.

  • paraphrase (str) – the paraphrased text.

  • data_record (dict) – the corresponding data record of original text.

Returns

predicted label

Return type

(np.int)

predict_log_dist_batch(origin, paraphrase_list, data_record=None)[source]

Predict the log-probability distribution over classes for one batch.

Parameters
  • origin (str) – the original text.

  • paraphrase_list (list) – a list of paraphrases.

  • data_record (dict) – the corresponding data record of original text.

Returns

a numpy array of shape (batch_size, num_labels).

Return type

(np.array)
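
To illustrate the expected output layout, the sketch below turns a batch of raw model scores into log-probabilities with a numerically stable log-softmax. It is an assumption that an override would start from raw scores like these; log_softmax is a helper written for this example, not a fibber function.

```python
import numpy as np


def log_softmax(logits):
    """Row-wise log-softmax for a (batch_size, num_labels) array."""
    # Subtract the row maximum first so the exponentials cannot overflow.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


# Made-up scores for two paraphrases and three classes.
logits = np.array([[2.0, 0.0, -1.0],
                   [0.5, 0.5, 0.5]])
log_dist = log_softmax(logits)
```

Each row of log_dist corresponds to one paraphrase, and exponentiating a row gives probabilities that sum to 1.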

predict_log_dist_example(origin, paraphrase, data_record=None)[source]

Predict the log-probability distribution over classes for one example.

Parameters
  • origin (str) – the original text.

  • paraphrase (str) – the paraphrased text.

  • data_record (dict) – the corresponding data record of original text.

Returns

a numpy array of size (num_labels).

Return type

(np.array)

predict_log_dist_multiple_examples(origin_list, paraphrase_list, data_record_list=None)[source]
predict_multiple_examples(origin_list, paraphrase_list, data_record_list=None)[source]
robust_tune_init(optimizer, lr, weight_decay, steps)[source]
robust_tune_step(data_record_list)[source]
save_robust_tuned_model(load_path)[source]