[ML] add zero_shot_classification task for BERT nlp models #77799
Conversation
Looks good, only some minor comments
@Override
public InferenceConfig toConfig() {
    throw new UnsupportedOperationException("cannot serialize to nodes before 7.8");
I'm not sure what the reason for this is. I'm guessing the error message is a copy-paste; if the idea is that implementing classes should implement this method, then remove this and let the compiler do its work.
@davidkyle we cannot create an NLP inference config from an update. In a separate PR I am gonna remove this check and this method as it is not used in master (so, this is an intermediate change really).
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
this.isMultiLabel = isMultiLabel != null && isMultiLabel;
this.hypothesisTemplate = Optional.ofNullable(hypothesisTemplate).orElse(DEFAULT_HYPOTHESIS_TEMPLATE);
If we allow labels to be null, perhaps hypothesisTemplate should not be defaulted, so that both must be defined at point of call.

Suggestion: rename labels to hypothesisLabels
> rename labels to hypothesisLabels

I don't think anybody calls them that. Another good option might be target_labels?
> If we allow labels to be null, perhaps hypothesisTemplate should not be defaulted, so that both must be defined at point of call.

The typical default (for MNLI-trained models) is the one we are providing. This is a nice quality-of-life improvement, I think.

I also don't think that labels should be required on creation. When the user is putting the model with its config, they have no idea what labels the model user will use. I think allowing null enforces that the person using the model has to provide the labels they want. The whole point of zero_shot is that you don't know/care about the labels until you call infer.
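The semantics being discussed can be sketched as follows. This is a hypothetical, standalone sketch, not the actual Elasticsearch class: the template falls back to the MNLI-style default when null, while labels may legitimately stay null at model-creation time and must be supplied by inference time (the default template string here is taken from the example in the PR description).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

// Hypothetical sketch of the defaulting behavior discussed above.
public class ZeroShotConfigSketch {
    static final String DEFAULT_HYPOTHESIS_TEMPLATE = "This example is {}.";

    final String hypothesisTemplate;
    final List<String> labels; // may be null until inference time

    ZeroShotConfigSketch(String hypothesisTemplate, List<String> labels) {
        // Template is defaulted; labels are intentionally NOT validated here,
        // because the model creator may not know which labels callers will use.
        this.hypothesisTemplate = Optional.ofNullable(hypothesisTemplate)
            .orElse(DEFAULT_HYPOTHESIS_TEMPLATE);
        this.labels = labels;
    }

    // At inference time, labels must have been provided (on the config
    // or via an update); only then are the hypothesis sequences built.
    List<String> hypotheses() {
        if (labels == null || labels.isEmpty()) {
            throw new IllegalArgumentException("labels must be provided before inference");
        }
        List<String> out = new ArrayList<>();
        for (String label : labels) {
            out.add(hypothesisTemplate.replace("{}", label));
        }
        return out;
    }
}
```

This illustrates why defaulting the template while leaving labels nullable is a quality-of-life choice rather than an inconsistency: the two fields have different natural owners (model creator vs. model user).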
@Override
public ZeroShotClassificationConfigUpdate.Builder setResultsField(String resultsField) {
    throw new IllegalArgumentException();
Seems a little harsh; I use this for the regression/classification models.
Well, we don't even have a results field in the original NLP configs. I think this is a larger discussion around unifying the NLP configs with the classification/regression configs.
It should really be part of the inference processor config. Looking at the code, this function appears to exist for use by InferencePipelineAggregationBuilder. ++ to revisiting this and simplifying the code.
@@ -414,6 +414,68 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]

`with_special_tokens`::::
@szabosteve mind taking a look at the doc changes?
Documentation LGTM! Thanks for writing the content! ✍️ I left three suggestions. This one will make the docs CI pass.
docs/reference/ml/ml-shared.asciidoc
it is possible to adjust the labels to classify. This makes this type of model
and task exceptionally flexible.

If consistently classifying the same labels, it may be better to use an optimized
Suggested change:
- If consistently classifying the same labels, it may be better to use an optimized
+ If consistently classifying the same labels, it may be better to use an fine tuned
If you accept this suggestion, please also change the indefinite article from an to a.
);
}
final double[] normalizedScores;
if (isMultiLabel) {
isMultiLabel is just about how the results are interpreted? When true, the probability of entailment (as opposed to contradiction) for each label is returned. When false, it is the probability of each label being the entailment. Can you help me understand this and update the docs?
isMultiLabel is basically:

- When true, it's softmax of individual entailment vs. contradiction (probabilities don't sum to 1.0)
- When false, it's softmax of all entailments (probabilities sum to 1.0)

The docs already state you use it when you could have more than one true label, which is exactly what we use it for. I would rather not talk about softmax, entailment, etc.
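The two interpretations above can be sketched numerically. This is a hypothetical standalone sketch, not the actual Elasticsearch code; it assumes each label yields MNLI-style logits ordered [contradiction, neutral, entailment], which is an assumption for illustration only.

```java
// Sketch of the two normalization modes described in the comment above.
public class MultiLabelSketch {
    // Standard numerically-stable softmax over a logit vector.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) max = Math.max(max, v);
        double sum = 0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) out[i] /= sum;
        return out;
    }

    // isMultiLabel == true: per label, softmax of entailment vs. contradiction.
    // Scores are independent and need not sum to 1.0 across labels.
    static double[] multiLabelScores(double[][] perLabelLogits) {
        double[] scores = new double[perLabelLogits.length];
        for (int i = 0; i < perLabelLogits.length; i++) {
            double[] l = perLabelLogits[i];
            scores[i] = softmax(new double[] { l[0], l[2] })[1]; // entailment prob
        }
        return scores;
    }

    // isMultiLabel == false: softmax over all entailment logits.
    // Scores sum to 1.0 across labels (a single-best-label distribution).
    static double[] singleLabelScores(double[][] perLabelLogits) {
        double[] ent = new double[perLabelLogits.length];
        for (int i = 0; i < perLabelLogits.length; i++) {
            ent[i] = perLabelLogits[i][2];
        }
        return softmax(ent);
    }
}
```

With uniform logits, both modes give each label 0.5, but only the single-label scores are constrained to sum to 1.0 as more labels are added.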
Co-authored-by: David Kyle <[email protected]> Co-authored-by: István Zoltán Szabó <[email protected]>
LGTM
Zero-shot classification allows for text classification tasks without a pre-trained collection of target labels. This is achieved through models trained on the Multi-Genre Natural Language Inference (MNLI) dataset, which pairs text sequences with "entailment" clauses. For example, "Throughout all of history, mankind has shown itself resourceful, yet astoundingly short-sighted" could be paired with the entailment clauses ["This example is history", "This example is sociology", ...]. This training set, combined with the attention and semantic knowledge in modern-day NLP models (BERT, BART, etc.), affords a powerful tool for ad-hoc text classification.
See https://arxiv.org/abs/1909.00161 for a deeper explanation of the MNLI training and how zero-shot works.
The zero_shot_classification task is configured as follows: given a label such as sad, the constructed sequence looks like This example is sad.
For inference in a pipeline, one may provide label updates: the labels we care about, which replace the default ones if they exist. Similarly, one may provide label changes against the _infer endpoint.
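The update semantics described above can be sketched as follows. This is a hypothetical helper for illustration, not the actual Elasticsearch code: labels supplied with the inference call (pipeline processor config or _infer request) replace any labels stored on the model configuration.

```java
import java.util.List;

// Hypothetical sketch: resolve the effective labels for an inference call.
public class LabelUpdateSketch {
    // Labels provided at inference time win; otherwise fall back to the
    // labels stored on the model config (which may themselves be null,
    // in which case the call must fail before inference runs).
    static List<String> effectiveLabels(List<String> configured, List<String> update) {
        return (update != null && !update.isEmpty()) ? update : configured;
    }
}
```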