
[ML] add zero_shot_classification task for BERT nlp models #77799

Conversation

benwtrent
Member

@benwtrent benwtrent commented Sep 15, 2021

Zero-Shot classification allows for text classification tasks without a pre-trained collection of target labels.

This is achieved through models trained on the Multi-Genre Natural Language Inference (MNLI) dataset. This dataset pairs text sequences with "entailment" clauses. An example could be:

"Throughout all of history, man kind has shown itself resourceful, yet astoundingly short-sighted" could have been paired with the entailment clauses: ["This example is history", "This example is sociology"...].

This training set, combined with the attention and semantic knowledge in modern-day NLP models (BERT, BART, etc.), affords a powerful tool for ad-hoc text classification.

See https://arxiv.org/abs/1909.00161 for a deeper explanation of the MNLI training and how zero-shot works.
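For illustration, the same MNLI-based approach can be tried outside Elasticsearch with the Hugging Face transformers zero-shot-classification pipeline. This is a rough sketch only; the checkpoint name facebook/bart-large-mnli and the library usage are assumptions for the example, not part of this change:

from transformers import pipeline

# Any MNLI-trained checkpoint should work; facebook/bart-large-mnli is a common choice.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

result = classifier(
    "Throughout all of history, mankind has shown itself resourceful, yet astoundingly short-sighted",
    candidate_labels=["history", "sociology", "economics"],
    hypothesis_template="This example is {}.",
    multi_label=False,
)
print(result["labels"])  # candidate labels sorted by score
print(result["scores"])  # probabilities, summing to 1.0 when multi_label=False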

The zero-shot classification task is configured as follows:

{
   // <snip> model configuration </snip>
  "inference_config" : {
    "zero_shot_classification": {
      "classification_labels": ["entailment", "neutral", "contradiction"], // <1>
      "labels": ["sad", "glad", "mad", "rad"], // <2>
      "multi_label": false, // <3>
      "hypothesis_template": "This example is {}.", // <4>
      "tokenization": { /*<snip> tokenization configuration </snip>*/}
    }
  }
}
  • <1> All zero-shot models return these three labels when classifying the target sequence against a hypothesis: "entailment" is the positive case, "neutral" the case where the sequence is neither positive nor negative, and "contradiction" the negative case
  • <2> An optional set of default labels to attempt to classify; these can be overridden at inference time
  • <3> When returning the probabilities, whether the results should assume there is only one true label or allow multiple true labels
  • <4> The hypothesis template used when tokenizing the labels. Combined with the label sad, the hypothesis reads This example is sad.
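To make the hypothesis template concrete, here is a small, purely illustrative Python sketch of how each candidate label is combined with hypothesis_template into a premise/hypothesis pair that an MNLI-trained model scores (the variable names are invented for the example):

sequence = "This is a very happy person"
labels = ["sad", "glad", "mad", "rad"]
hypothesis_template = "This example is {}."

# One premise/hypothesis pair per candidate label; the model scores each pair
# as entailment / neutral / contradiction.
pairs = [(sequence, hypothesis_template.format(label)) for label in labels]
for premise, hypothesis in pairs:
    print(premise, "->", hypothesis)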

For inference in an ingest pipeline, one may provide label updates:

{
  //<snip> pipeline definition </snip>
  "processors": [
    //<snip> other processors </snip>
    {
      "inference": {
        // <snip> general configuration </snip>
        "inference_config": {
          "zero_shot_classification": {
             "labels": ["humanities", "science", "mathematics", "technology"], // <1>
             "multi_label": true // <2>
          }
        }
      }
    }
    //<snip> other processors </snip>
  ]
}
  • <1> The labels we care about; these replace the defaults if any are configured.
  • <2> Whether the results should allow multiple true labels

Similarly, one may provide label changes against the _infer endpoint:

{
   "docs":[{ "text_field": "This is a very happy person"}],
   "inference_config":{"zero_shot_classification":{"labels": ["glad", "sad", "bad", "rad"], "multi_label": false}}
}

@benwtrent benwtrent force-pushed the feature/ml-add-zero-shot-classification-support branch 2 times, most recently from 1b308e7 to d3118d9 Compare September 16, 2021 12:14
@benwtrent
Member Author

run elasticsearch-ci/part-2

@benwtrent benwtrent force-pushed the feature/ml-add-zero-shot-classification-support branch 3 times, most recently from 6eacd1a to 1da856b Compare September 21, 2021 12:27
@benwtrent benwtrent force-pushed the feature/ml-add-zero-shot-classification-support branch from 1da856b to 1974ec2 Compare September 21, 2021 19:28
@benwtrent benwtrent marked this pull request as ready for review September 22, 2021 11:27
@elasticmachine elasticmachine added the "Team:ML" (Meta label for the ML team) label Sep 22, 2021
@benwtrent benwtrent changed the title from "[ML] add zero_shot_classification task for nlp models" to "[ML] add zero_shot_classification task for BERT nlp models" Sep 23, 2021
Member

@davidkyle davidkyle left a comment

Looks good, only some minor comments


@Override
public InferenceConfig toConfig() {
    throw new UnsupportedOperationException("cannot serialize to nodes before 7.8");
Member

I'm not sure what the reason for this is. I'm guessing the error message is a copy-paste; if the idea is that implementing classes should implement this method, then remove this and let the compiler do its work.

Member Author

@davidkyle we cannot create an NLP inference config from an update. In a separate PR I am going to remove this check and this method, as it is not used in master (so this is really an intermediate change).

.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
this.isMultiLabel = isMultiLabel != null && isMultiLabel;
this.hypothesisTemplate = Optional.ofNullable(hypothesisTemplate).orElse(DEFAULT_HYPOTHESIS_TEMPLATE);
Member

If we allow labels to be null, perhaps hypothesisTemplate should not be defaulted, so that both must be defined at the point of call.

Suggestion: rename labels to hypothesisLabels

Member Author

rename labels to hypothesisLabels

I don't think anybody calls them that.

Another good option might be target_labels?

Member Author

If we allow labels to be null perhaps hypothesisTemplate should not be defaulted so that both must be defined at point of call.

The typical default (for MNLI-trained models) is the one we are providing. This is a nice quality-of-life improvement, I think.

I also don't think that labels should be required on creation. When the user is putting the model with its config, they have no idea what labels the model's users will use. I think allowing null enforces that the person using the model has to provide the labels they want.

The whole point of zero_shot is that you don't know/care about the labels until you call infer.


@Override
public ZeroShotClassificationConfigUpdate.Builder setResultsField(String resultsField) {
    throw new IllegalArgumentException();
Member

Seems a little harsh; I use this for the regression/classification models.

Member Author

Well, we don't even have a results field in the original NLP configs.

I think this is a larger discussion around unifying the NLP Configs with the classification/regression configs.

Member

It should really be part of the inference processor config. Looking at the code this function appears to exist for use by InferencePipelineAggregationBuilder. ++ to revisiting this and simplifying the code.

@benwtrent benwtrent requested a review from davidkyle September 27, 2021 17:08
@@ -414,6 +414,68 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]

`with_special_tokens`::::
Member Author

@szabosteve mind taking a look at the doc changes?

Contributor

@szabosteve szabosteve left a comment

Documentation LGTM! Thanks for writing the content! ✍️ I left three suggestions. This one will make the docs CI pass.

it is possible to adjust the labels to classify. This makes this type of model
and task exceptionally flexible.

If consistently classifying the same labels, it may be better to use an optimized
Member

Suggested change
If consistently classifying the same labels, it may be better to use an optimized
If consistently classifying the same labels, it may be better to use an fine tuned

Contributor

If you accept this suggestion, please also change the indefinite article to a from an.

);
}
final double[] normalizedScores;
if (isMultiLabel) {
Member

@davidkyle davidkyle Sep 28, 2021

isMultiLabel is just about how the results are interpreted? When true, the probability of entailment as opposed to contradiction for each label is returned. When false, it is the probability of each label being entailment. Can you help me understand this and update the docs?

Member Author

isMultiLabel is basically:

  • When true, it's the softmax of individual entailment vs contradiction (probabilities don't sum to 1.0)
  • When false, it's the softmax of all entailments (probabilities sum to 1.0)

The docs already state you use it when you could have more than one true label, which is exactly what we use it for.

I would rather not talk about softmax, entailment, etc.
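For anyone following along, here is a rough numeric sketch of the two interpretations described above, using made-up logits; it is illustrative only and not the actual Elasticsearch implementation:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Hypothetical model logits, one row per candidate label,
# columns ordered as [entailment, neutral, contradiction].
logits = np.array([
    [ 3.1, 0.2, -2.0],  # "glad"
    [-1.5, 0.1,  2.2],  # "sad"
    [ 0.3, 0.0,  0.4],  # "rad"
])

# multi_label = true: per label, softmax over [entailment, contradiction] and
# keep the entailment probability; the scores need not sum to 1.0.
multi = softmax(logits[:, [0, 2]], axis=1)[:, 0]

# multi_label = false: softmax over the entailment logits across all labels;
# the scores sum to 1.0.
single = softmax(logits[:, 0], axis=0)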


benwtrent and others added 2 commits September 28, 2021 07:20
Co-authored-by: David Kyle <[email protected]>
Co-authored-by: István Zoltán Szabó <[email protected]>
@benwtrent benwtrent requested a review from davidkyle September 28, 2021 12:09
Member

@davidkyle davidkyle left a comment

LGTM

@benwtrent benwtrent merged commit 4084893 into elastic:master Sep 28, 2021
@benwtrent benwtrent deleted the feature/ml-add-zero-shot-classification-support branch September 28, 2021 13:38
Labels
:ml Machine learning · >non-issue · Team:ML Meta label for the ML team · v8.0.0-beta1