fibber.metrics.classifier.classifier_base module
class fibber.metrics.classifier.classifier_base.ClassifierBase(field, bs=32, **kwargs)
Bases: fibber.metrics.metric_base.MetricBase
Base class for classifiers.
All classifiers must be derived from this class.
To implement a new classifier, you should at least override the predict_log_dist_example method, which returns the predicted log-probability distribution over classes. Some classifiers output a label instead of a distribution; in that case, return a one-hot vector.
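For a classifier that only produces a hard label, the one-hot vector can be built with NumPy. This is a minimal sketch; `label_to_one_hot` and `num_labels` are hypothetical names chosen for illustration, not part of fibber.

```python
import numpy as np


def label_to_one_hot(label: int, num_labels: int) -> np.ndarray:
    """Hypothetical helper: encode a hard label as a one-hot vector of size num_labels."""
    if not 0 <= label < num_labels:
        raise ValueError(f"label {label} out of range for {num_labels} classes")
    one_hot = np.zeros(num_labels, dtype=np.float64)
    one_hot[label] = 1.0
    return one_hot


# A label-only classifier that predicted class 2 out of 4 classes.
print(label_to_one_hot(2, 4))  # prints [0. 0. 1. 0.]
```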
Some classifiers run more efficiently on a batch of data. In that case, you should also override the predict_log_dist_batch method. If you don't override predict_log_dist_batch, the default implementation processes the paraphrases in paraphrase_list one by one.
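The override pattern described above can be sketched as follows. `ToyClassifierBase` is a simplified, hypothetical stand-in for ClassifierBase (it is not fibber's code and omits the real constructor arguments `field` and `bs`); it only mimics the documented fallback of handling paraphrase_list one paraphrase at a time. `LengthClassifier` and `BatchedLengthClassifier` are toy classifiers invented for illustration.

```python
import numpy as np


class ToyClassifierBase:
    """Hypothetical stand-in for ClassifierBase (NOT fibber's implementation)."""

    def predict_log_dist_example(self, origin, paraphrase, data_record=None):
        raise NotImplementedError

    def predict_log_dist_batch(self, origin, paraphrase_list, data_record=None):
        # Fallback: one call per paraphrase, stacked into (batch_size, num_labels).
        return np.stack([
            self.predict_log_dist_example(origin, p, data_record)
            for p in paraphrase_list
        ])


class LengthClassifier(ToyClassifierBase):
    """Toy 2-class classifier: class 1 if the paraphrase is longer than origin."""

    def predict_log_dist_example(self, origin, paraphrase, data_record=None):
        if len(paraphrase) <= len(origin):
            probs = np.array([0.9, 0.1])
        else:
            probs = np.array([0.1, 0.9])
        return np.log(probs)


class BatchedLengthClassifier(LengthClassifier):
    """Same predictions, computed for the whole batch in one vectorized step."""

    def predict_log_dist_batch(self, origin, paraphrase_list, data_record=None):
        longer = np.array([len(p) > len(origin) for p in paraphrase_list])
        # Broadcast a (batch_size, 1) mask against per-class rows of shape (2,).
        probs = np.where(longer[:, None], [0.1, 0.9], [0.9, 0.1])
        return np.log(probs)


origin = "a short text"
paraphrases = ["short", "a much longer paraphrase"]
looped = LengthClassifier().predict_log_dist_batch(origin, paraphrases)
batched = BatchedLengthClassifier().predict_log_dist_batch(origin, paraphrases)
print(looped.shape)                  # (2, 2)
print(np.allclose(looped, batched))  # True
```

Overriding the batch method changes only how the result is computed; both versions must return the same (batch_size, num_labels) array.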
predict_batch(origin, paraphrase_list, data_record=None)
Predict class labels for a batch of paraphrases.
- Parameters
origin (str) – the original text.
paraphrase_list (list) – a list of paraphrases of the original text.
data_record (dict) – the data record corresponding to the original text.
- Returns
predicted labels as a numpy array of shape (batch_size,).
- Return type
(np.ndarray)
predict_example(origin, paraphrase, data_record=None)
Predict the class label for one example.
- Parameters
origin (str) – the original text.
paraphrase (str) – a paraphrase of the original text.
data_record (dict) – the data record corresponding to the original text.
- Returns
predicted label
- Return type
(np.int)
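Both predict_example and predict_batch return hard class labels. Assuming the labels are derived from the log-probability outputs (an assumption; this page does not say how they are computed), the natural rule is the arg-max over the class axis. The function names below are hypothetical.

```python
import numpy as np


def label_from_log_dist(log_dist: np.ndarray) -> int:
    """Hypothetical: label for one example from a (num_labels,) log-distribution."""
    return int(np.argmax(log_dist))


def labels_from_log_dist_batch(log_dist_batch: np.ndarray) -> np.ndarray:
    """Hypothetical: labels of shape (batch_size,) from a (batch_size, num_labels) array."""
    return np.argmax(log_dist_batch, axis=1)


batch = np.log(np.array([[0.7, 0.2, 0.1],
                         [0.1, 0.3, 0.6]]))
print(labels_from_log_dist_batch(batch))  # [0 2]
print(label_from_log_dist(batch[1]))      # 2
```

Because the logarithm is monotonic, taking the arg-max of log-probabilities gives the same label as taking it over probabilities, and it also works unchanged on one-hot vectors.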
predict_log_dist_batch(origin, paraphrase_list, data_record=None)
Predict the log-probability distribution over classes for one batch.
- Parameters
origin (str) – the original text.
paraphrase_list (list) – a list of paraphrases of the original text.
data_record (dict) – the data record corresponding to the original text.
- Returns
a numpy array of shape (batch_size, num_labels).
- Return type
(np.ndarray)
predict_log_dist_example(origin, paraphrase, data_record=None)
Predict the log-probability distribution over classes for one example.
- Parameters
origin (str) – the original text.
paraphrase (str) – a paraphrase of the original text.
data_record (dict) – the data record corresponding to the original text.
- Returns
a numpy array of shape (num_labels,).
- Return type
(np.ndarray)
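If an underlying model produces raw, unnormalized scores (logits), a numerically stable log-softmax turns them into the log-probability distribution this method is documented to return. The sketch below uses plain NumPy; `log_softmax` is a hypothetical helper here (SciPy provides an equivalent, scipy.special.log_softmax). Normalizing along the last axis lets the same function serve a single example of shape (num_labels,) and a batch of shape (batch_size, num_labels).

```python
import numpy as np


def log_softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable log-softmax along `axis`."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtracting the max keeps exp() from overflowing without changing the result.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))


example_logits = np.array([2.0, 1.0, -1.0])
log_dist = log_softmax(example_logits)
print(log_dist.shape)                           # (3,)
print(np.isclose(np.exp(log_dist).sum(), 1.0))  # True
```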