
GRPO with DPO Loss for Debiased Multi-Label Classification

Soumen and Akshit

April 9, 2025

1 Problem Setup: Debiased Multi-Label Classification


We consider a multi-label classification problem with $T$ possible labels, $\mathcal{T} = \{1, 2, \ldots, T\}$. We have a dataset $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{N}$, where $x_i$ is an input and $y_i = [y_{i1}, \ldots, y_{iT}]$ is the ground-truth label vector, with $y_{il} = +1$ if label $l$ applies to $x_i$, and $y_{il} = 0$ otherwise.

The set of labels $\mathcal{T}$ is partitioned into a privileged set $\mathcal{P} \subset \mathcal{T}$ and a non-privileged set $\bar{\mathcal{P}} = \mathcal{T} \setminus \mathcal{P}$. We assume a policy model $\pi_\theta(y|x)$, where $\theta$ denotes the pretrained model parameters. We wish to finetune the pretrained policy model by introducing, for each label $t$, policy parameters $w_t$ that produce a probability score for label $t$ given input $x$, denoted $m(x; w_t)$. We also have access to a reference model with parameters $\hat{w}_t$, producing scores $m(x; \hat{w}_t)$.

For a ground-truth positive label $l$ ($y_{il} = +1$), we define the set of confusing negative labels for instance $x_i$ as:

$$S_{il} = \{k \in \mathcal{T} \mid y_{ik} = 0 \text{ and } m(x_i; w_k) \geq m(x_i; w_l)\} \quad (1)$$
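As a concrete illustration, here is a minimal Python sketch of Eq. (1); the function name is ours, and we assume scores holds the current per-label scores $[m(x_i; w_1), \ldots, m(x_i; w_T)]$ and y the ground-truth vector $y_i$:

def confusing_negatives(scores, y, l):
    # S_il of Eq. (1): negative labels (y[k] == 0) whose current score is
    # at least as high as the score of the true positive label l.
    return [k for k in range(len(scores))
            if y[k] == 0 and scores[k] >= scores[l]]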
The fairness goals are:
1. Privileged ($\mathcal{P}$): For a true positive label $l \in \mathcal{P}$, the model should strongly distinguish it from its confusing negatives $k \in S_{il}$. We want $m(x_i; w_l)$ to be significantly higher than $m(x_i; w_k)$.

2. Non-Privileged ($\bar{\mathcal{P}}$): For a label $j \in \bar{\mathcal{P}}$, the classification performance under the learned model $w_j$ should remain close to the performance under the reference model $\hat{w}_j$. We desire $\mathrm{loss}(x_i; w_j) \leq \mathrm{loss}(x_i; \hat{w}_j) + \epsilon$. (We could use $-\epsilon$ instead if we want our model to perform even better on the non-privileged set.)

2 Connecting DPO and the Privileged Goal

Direct Preference Optimization (DPO) learns from preferences $y_p \succ y_d$, where $y_p$ is the preferred response and $y_d$ is the dispreferred response for an input $x$. We can reframe the privileged goal as expressing a preference: for a given instance $x_i$ where $y_{il} = +1$ and $l \in \mathcal{P}$, we prefer label $l$ over any confusing negative label $k \in S_{il}$. DPO compares the relative log-probability assigned by the current policy ($w$) vs. the reference policy ($\hat{w}$). Define a DPO-like term comparing the preferred label $l$ (winner) against the confusing label $k$ (loser) for input $x_i$:
   
$$h_w(x_i, l, k) = \underbrace{\log\frac{m(x_i; w_l)}{m(x_i; \hat{w}_l)}}_{\text{likelihood for preferred } l} - \underbrace{\log\frac{m(x_i; w_k)}{m(x_i; \hat{w}_k)}}_{\text{likelihood for dispreferred } k} \quad (2)$$

$$m(x_i; w_t) = \sigma(w_t^\top z_i), \quad \text{where } z_i = \pi_\theta(x_i) \in \mathbb{R}^d \quad (3)$$

This term $h_w$ measures how much more the current model $w$ separates label $l$ from label $k$ compared to how the reference model $\hat{w}$ separates them. We want $h_w$ to be as large as possible. If $\pi_\theta$ is an LLM, then $\pi_\theta(x_i)$ could be the output embedding at the last token position ($d$ is the embedding dimension). The embedding vector could then be passed to an FFNN classifier whose number of output neurons equals the number of labels $T$. In this context, $w_t$ represents the weight vector attached to the $t$-th label from the input $z$. $\hat{w}$ could be obtained by performing SFT (supervised finetuning) on $\pi_\theta$ with a multi-label classification objective on the multi-label classification dataset.
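To make this concrete, here is a minimal PyTorch sketch of the scoring head of Eq. (3) and the separation term of Eq. (2); the function names and the $(T, d)$ weight-matrix layout are our assumptions, not prescribed above:

import torch

def label_scores(z, W):
    # Eq. (3): m(x; w_t) = sigma(w_t^T z) for every label t.
    # z: (d,) input embedding (e.g., the LLM's last-token embedding);
    # W: (T, d) matrix whose rows are the per-label weight vectors w_t.
    return torch.sigmoid(W @ z)

def h_w(z, W, W_ref, l, k):
    # Eq. (2): how much more the current head W separates the preferred
    # label l from the confusing label k than the reference head W_ref does.
    m, m_ref = label_scores(z, W), label_scores(z, W_ref)
    return (torch.log(m[l]) - torch.log(m_ref[l])) \
        - (torch.log(m[k]) - torch.log(m_ref[k]))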

3 GRPO Framework with DPO-Inspired Loss


We apply the GRPO framework by defining groups and losses based on the privileged/non-privileged
distinction.

3.1 Group Definition


• Group $\mathcal{P}$ ($G_{\mathcal{P}}$): Consists of triplets $(x_i, l, k)$ where $l \in \mathcal{P}$, $y_{il} = +1$, and $k \in S_{il}$. This group focuses on the DPO-like objective for privileged labels.

• Group $\bar{\mathcal{P}}$ ($G_{\bar{\mathcal{P}}}$): Consists of pairs $(x_i, j)$ where $j \in \bar{\mathcal{P}}$. This group focuses on the margin-based loss constraint for non-privileged labels (a construction sketch follows the list).
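The following sketch shows one way the two groups could be materialized; since $S_{il}$ depends on the current scores, $G_{\mathcal{P}}$ would need to be rebuilt as training progresses (all names here are illustrative):

def build_groups(dataset, scores_fn, P, T):
    # G_P: triplets (i, l, k) with l privileged, y_il = +1, and k a
    # confusing negative from S_il (computed with the *current* scores).
    # G_Pbar: pairs (i, j) for every non-privileged label j.
    G_P, G_Pbar = [], []
    for i, (x, y) in enumerate(dataset):
        m = scores_fn(x)  # current scores m(x_i; w_t) for all t
        for l in P:
            if y[l] == 1:
                S_il = [k for k in range(T) if y[k] == 0 and m[k] >= m[l]]
                G_P += [(i, l, k) for k in S_il]
        G_Pbar += [(i, j) for j in range(T) if j not in P]
    return G_P, G_Pbar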

3.2 Loss Functions


Privileged Loss ($L_{\mathcal{P}}$): We apply the structure of the DPO loss using our defined term $h_w(x_i, l, k)$. We minimize the negative log-likelihood of the preference $l \succ k$.

$$L_{\mathcal{P}}(\{w_t\}_{t \in \mathcal{T}}, \{\hat{w}_t\}_{t \in \mathcal{T}}) = \mathbb{E}_{(x_i, l, k) \,\text{s.t.}\, l \in \mathcal{P},\, y_{il} = +1,\, k \in S_{il}} \left[ -\log\sigma\left(\beta \cdot h_w(x_i, l, k)\right) \right] \quad (4)$$

Minimizing this loss encourages $h_w(x_i, l, k)$ to be large and positive, effectively pushing $m(x_i; w_l)$ higher relative to $m(x_i; w_k)$ compared to the reference model.
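A per-triplet sketch of Eq. (4), reusing h_w from the earlier snippet ($\beta$ acts as an inverse temperature; logsigmoid is used for numerical stability):

import torch.nn.functional as F

def privileged_loss(z, W, W_ref, l, k, beta):
    # Eq. (4) for a single triplet (x_i, l, k): -log sigma(beta * h_w).
    return -F.logsigmoid(beta * h_w(z, W, W_ref, l, k))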

Non-Privileged Loss ($L_{\bar{\mathcal{P}}}$): This loss enforces the constraint on standard classification performance for non-privileged labels. We use a base binary cross-entropy loss function.

$$L_{\bar{\mathcal{P}}}(\{w_t\}_{t \in \mathcal{T}}, \{\hat{w}_t\}_{t \in \mathcal{T}}) = \mathbb{E}_{(x_i, \{y_{ij} \mid j \in \bar{\mathcal{P}}\}) \,\text{s.t.}\, y_{ij} \in \{+1, 0\}} \left[ L_h(w_j, \hat{w}_j; \epsilon) \right] \quad (5)$$
$$L_h(w_j, \hat{w}_j; \epsilon) = \max\left(0, \mathrm{loss}(w_j) - \mathrm{loss}(\hat{w}_j) - \epsilon\right) \quad (6)$$
$$\mathrm{loss}(w_j) = -y_{ij}\log(m(x_i; w_j)) - (1 - y_{ij})\log(1 - m(x_i; w_j)) \quad (7)$$
$$\mathrm{loss}(\hat{w}_j) = -y_{ij}\log(m(x_i; \hat{w}_j)) - (1 - y_{ij})\log(1 - m(x_i; \hat{w}_j)) \quad (8)$$

The non-privileged loss penalizes the model whenever its binary cross-entropy on a non-privileged label $j$ exceeds the reference model's loss by more than the slack $\epsilon$.
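A per-pair sketch of Eqs. (6)-(8), with z, W, and W_ref as in the earlier snippets and y_j the ground-truth value $y_{ij} \in \{0, 1\}$:

import torch

def non_privileged_loss(z, W, W_ref, j, y_j, eps):
    # Eqs. (7)-(8): binary cross-entropy under the current and (frozen)
    # reference heads.
    m = torch.sigmoid(W[j] @ z)
    m_ref = torch.sigmoid(W_ref[j] @ z).detach()
    bce = -(y_j * torch.log(m) + (1 - y_j) * torch.log(1 - m))
    bce_ref = -(y_j * torch.log(m_ref) + (1 - y_j) * torch.log(1 - m_ref))
    # Eq. (6): hinge on the excess of the current loss over the reference
    # loss, with slack eps; zero whenever the constraint is satisfied.
    return torch.clamp(bce - bce_ref - eps, min=0.0)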

3.3 GRPO-DPO Objective


The objective is the GRPO minimax formulation using these specific losses:

$$\min_{w} \; \max_{\alpha_{\mathcal{P}}, \alpha_{\bar{\mathcal{P}}} \geq 0,\ \alpha_{\mathcal{P}} + \alpha_{\bar{\mathcal{P}}} = 1} \left[ \alpha_{\mathcal{P}} L_{\mathcal{P}}(w, \hat{w}) + \alpha_{\bar{\mathcal{P}}} L_{\bar{\mathcal{P}}}(w, \hat{w}) \right] \quad (9)$$

The inner maximization adaptively places more weight on whichever group currently incurs the larger loss, so the model cannot improve the privileged objective at the expense of the non-privileged constraint (or vice versa).

4 GRPO-DPO Algorithm for Classification

Algorithm 1 GRPO-DPO for Debiased Multi-Label Classification

1: Initialize: model parameters $\{w_t^{(0)} \in \mathbb{R}^d \mid \forall t \in \mathcal{T}\}$ (e.g., copy $\{\hat{w}_t \in \mathbb{R}^d \mid \forall t \in \mathcal{T}\}$), group weights $\alpha_{\mathcal{P}}^{(0)} \leftarrow 0.5$, $\alpha_{\bar{\mathcal{P}}}^{(0)} \leftarrow 0.5$.
2: Choose: learning rates $\eta_w, \eta_\alpha$, DPO hyperparameter $\beta$, reference parameters $\{\hat{w}_t \mid \forall t \in \mathcal{T}\}$, constraint slack $\epsilon$.
3: for $s = 0$ to $S$ (MaxIterations) do
4:   Sample an example $(x_i, [y_{i1}, y_{i2}, \ldots, y_{iT}]) \sim \mathcal{D}$.
5:   Initialize gradients and losses: $g_{\mathcal{P}}^t \leftarrow \vec{0}$, $g_{\bar{\mathcal{P}}}^t \leftarrow \vec{0}$ $\forall t \in \mathcal{T}$; $L_{\mathcal{P}}^{(s)} \leftarrow 0$, $L_{\bar{\mathcal{P}}}^{(s)} \leftarrow 0$.
6:   Forward pass: $m(x_i; w_t^{(s)}) \leftarrow \sigma(w_t^{(s)\top} z_i)$ where $z_i \leftarrow \pi_\theta(x_i)$, $\forall t \in \mathcal{T}$.
7:   Sample a label: $r \sim \mathrm{Uniform}(\mathcal{T})$.
8:   if $r \in \mathcal{P}$ and $y_{ir} = +1$ then
9:     $l \leftarrow r$
10:     $S_{il} \leftarrow \{k \in \mathcal{T} \mid y_{ik} = 0 \text{ and } m(x_i; w_k^{(s)}) \geq m(x_i; w_l^{(s)})\}$   ▷ Based on current scores
11:     Sample a confusing label: $k \sim \mathrm{Uniform}(S_{il})$
12:     $h_{w^{(s)}}(x_i, l, k) \leftarrow \log\big(m(x_i; w_l^{(s)}) / m(x_i; \hat{w}_l)\big) - \log\big(m(x_i; w_k^{(s)}) / m(x_i; \hat{w}_k)\big)$
13:     $L_{\mathcal{P}}^{(s)} \leftarrow -\log\sigma\big(\beta \cdot h_{w^{(s)}}(x_i, l, k)\big)$   ▷ Privileged loss
14:     $g_{\mathcal{P}}^t \leftarrow g_{\mathcal{P}}^t + \nabla_{w_t} L_{\mathcal{P}}^{(s)} \big|_{w_t^{(s)}}$ $\forall t \in \mathcal{T}$   ▷ Only the gradients w.r.t. $w_l$ and $w_k$ are non-zero
15:  else if $r \in \bar{\mathcal{P}}$ then
16:     $j \leftarrow r$
17:     $\mathrm{loss}(w_j^{(s)}) \leftarrow -y_{ij}\log\big(m(x_i; w_j^{(s)})\big) - (1 - y_{ij})\log\big(1 - m(x_i; w_j^{(s)})\big)$
18:     $\mathrm{loss}(\hat{w}_j) \leftarrow -y_{ij}\log\big(m(x_i; \hat{w}_j)\big) - (1 - y_{ij})\log\big(1 - m(x_i; \hat{w}_j)\big)$
19:     $L_{\bar{\mathcal{P}}}^{(s)} \leftarrow \max\big(0, \mathrm{loss}(w_j^{(s)}) - \mathrm{loss}(\hat{w}_j) - \epsilon\big)$   ▷ Non-privileged loss
20:     $g_{\bar{\mathcal{P}}}^t \leftarrow g_{\bar{\mathcal{P}}}^t + \nabla_{w_t} L_{\bar{\mathcal{P}}}^{(s)} \big|_{w_t^{(s)}}$ $\forall t \in \mathcal{T}$   ▷ Only the gradient w.r.t. $w_j$ is non-zero
21:  end if
22:  $\alpha_{\mathcal{P}}^{(s+1)} \leftarrow \alpha_{\mathcal{P}}^{(s)} \exp(\eta_\alpha L_{\mathcal{P}}^{(s)})$ and $\alpha_{\bar{\mathcal{P}}}^{(s+1)} \leftarrow \alpha_{\bar{\mathcal{P}}}^{(s)} \exp(\eta_\alpha L_{\bar{\mathcal{P}}}^{(s)})$   ▷ Mirror ascent step
23:  $Z \leftarrow \alpha_{\mathcal{P}}^{(s+1)} + \alpha_{\bar{\mathcal{P}}}^{(s+1)}$
24:  $\alpha_{\mathcal{P}}^{(s+1)} \leftarrow \alpha_{\mathcal{P}}^{(s+1)} / Z$ and $\alpha_{\bar{\mathcal{P}}}^{(s+1)} \leftarrow \alpha_{\bar{\mathcal{P}}}^{(s+1)} / Z$   ▷ Weight normalization step
25:  $w_t^{(s+1)} \leftarrow w_t^{(s)} - \eta_w\big(\alpha_{\mathcal{P}}^{(s+1)} g_{\mathcal{P}}^t + \alpha_{\bar{\mathcal{P}}}^{(s+1)} g_{\bar{\mathcal{P}}}^t\big)$ $\forall t \in \mathcal{T}$   ▷ Mirror descent step
26: end for
27: return $\{w_t^{(S)} \mid \forall t \in \mathcal{T}\}$
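Below is a minimal runnable PyTorch sketch of Algorithm 1. It assumes a frozen embedding function pi_theta, a list of (x, y) pairs with y a length-$T$ 0/1 tensor, and a frozen $(T, d)$ reference head W_ref; these names, the single weight-matrix parameterization, and the default hyperparameters are our choices, not prescribed by the algorithm:

import random
import torch

def grpo_dpo_train(data, pi_theta, W_ref, P, T, S=10_000,
                   eta_w=1e-3, eta_alpha=1e-2, beta=1.0, eps=0.05):
    W = W_ref.clone().requires_grad_(True)   # initialize w^(0) from w_hat
    alpha_P, alpha_Pbar = 0.5, 0.5           # group weights
    Pbar = [t for t in range(T) if t not in P]

    for s in range(S):
        x, y = random.choice(data)           # sample an example from D
        z = pi_theta(x).detach()
        m = torch.sigmoid(W @ z)             # forward pass: m(x_i; w_t)
        m_ref = torch.sigmoid(W_ref @ z).detach()
        L_P = torch.zeros(())                # losses default to zero when
        L_Pbar = torch.zeros(())             # a branch is not taken
        r = random.randrange(T)              # sample a label uniformly

        if r in P and y[r] == 1:             # privileged branch
            l = r
            S_il = [k for k in range(T) if y[k] == 0 and m[k] >= m[l]]
            if S_il:
                k = random.choice(S_il)      # sample a confusing negative
                h = (torch.log(m[l]) - torch.log(m_ref[l])
                     - torch.log(m[k]) + torch.log(m_ref[k]))
                L_P = -torch.nn.functional.logsigmoid(beta * h)
        elif r in Pbar:                      # non-privileged branch
            j, yj = r, float(y[r])
            bce = -(yj * torch.log(m[j]) + (1 - yj) * torch.log(1 - m[j]))
            bce_ref = -(yj * torch.log(m_ref[j])
                        + (1 - yj) * torch.log(1 - m_ref[j]))
            L_Pbar = torch.clamp(bce - bce_ref - eps, min=0.0)

        # mirror ascent on the group weights, then normalization
        alpha_P *= float(torch.exp(eta_alpha * L_P.detach()))
        alpha_Pbar *= float(torch.exp(eta_alpha * L_Pbar.detach()))
        Z = alpha_P + alpha_Pbar
        alpha_P, alpha_Pbar = alpha_P / Z, alpha_Pbar / Z

        # weighted (mirror) descent step on the classification head
        total = alpha_P * L_P + alpha_Pbar * L_Pbar
        if total.requires_grad:
            (grad,) = torch.autograd.grad(total, W)
            with torch.no_grad():
                W -= eta_w * grad
    return W.detach()

In practice one would likely process minibatches of triplets and pairs per step rather than a single sampled label, but the single-sample loop above mirrors the algorithm as written.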

5 Possible Image Datasets


NIH Chest X-ray14
Description: A large dataset of chest X-ray images, where each image is labeled with potentially
multiple common thoracic pathologies. This is inherently a multi-label classification task.

Labels (T ): {Atelectasis, Cardiomegaly, Effusion, Infiltration, Mass, Nodule, Pneumonia, Pneu-
mothorax, Consolidation, Edema, Emphysema, Fibrosis, Pleural Thickening, Hernia} (T=14)

Privileged Labels (P): {Pneumothorax, Mass, Nodule, Hernia} (Chosen based on clinical sever-
ity/urgency and potentially lower frequency compared to findings like Infiltration or Effusion).

Example (xi , yi ):
• xi : A frontal chest X-ray image.
• yi : [Atelectasis=0, Cardiomegaly=1, Effusion=1, Infiltration=0, Mass=0, Nodule=0,
Pneumonia=0, Pneumothorax=0, Consolidation=0, Edema=0, Emphysema=0, Fibrosis=0,
Pleural Thickening=0, Hernia=0] (Image shows cardiomegaly and effusion).

EuroSAT
Description: Satellite images from Sentinel-2 covering various land use/land cover classes across
Europe. This is typically used as a multi-class classification task (each image belongs to one class).

Labels (T ): {Annual Crop, Forest, Herbaceous Vegetation, Highway, Industrial, Pasture, Per-
manent Crop, Residential, River, Sea Lake} (T=10)

Privileged Labels (P): {Highway, Industrial, Sea Lake} (Chosen as potentially less frequent or
visually confusable classes compared to broader categories like Forest or Residential).

Example (xi , yi ):
• xi : A Sentinel-2 satellite image patch.
• yi : [Annual Crop=0, Forest=1, Herbaceous Vegetation=0, Highway=0, Industrial=0,
Pasture=0, Permanent Crop=0, Residential=0, River=0, Sea Lake=0] (Image shows a
forested area).

CIFAR-100
Description: A dataset of 60,000 32x32 color images in 100 fine-grained classes, grouped into 20
coarse-grained classes. We consider the 100 fine-grained classes for a multi-class task.

Labels (T ): {apple, aquarium fish, ..., worm} (100 classes total, e.g., beaver, dolphin, otter, seal,
whale are all ’aquatic mammals’).

Privileged Labels (P): {otter, seal, rabbit, squirrel} (Chosen as examples of fine-grained classes
that might be easily confused with other similar animals within the same coarse category, requiring
better distinction).

Example (xi , yi ):
• xi : A 32x32 color image.
• yi : A 100-element binary vector with a single '1' at the index corresponding to the true class. For instance, if the image is a 'seal': [..., otter = 0, ..., seal = 1, ..., whale = 0, ...]

6 Possible Text Datasets
MovieLens Tag Genome (Conceptual - for Genre Prediction)
Description: While Tag Genome provides fine-grained tags, we conceptualize a dataset derived
from it or similar movie databases (like IMDb plot summaries) for multi-label genre prediction
based on plot text.

Labels (T ): {Action, Adventure, Animation, Comedy, Crime, Drama, Fantasy, Horror, Mystery,
Romance, Sci-Fi, Thriller} (T=12)

Privileged Labels (P): {Animation, Fantasy, Sci-Fi} (Chosen as genres that might be less
frequent than Drama/Comedy or have overlapping themes with others, e.g., Fantasy/Adventure).

Example (xi , yi ):

• xi : ”A young farm boy discovers a hidden destiny and joins a group of rebels to battle an
evil empire ruling the galaxy, learning the ways of a mystical energy field.”

• yi : [Action=1, Adventure=1, Animation=0, Comedy=0, Crime=0, Drama=0, Fantasy=1, Horror=0, Mystery=0, Romance=0, Sci-Fi=1, Thriller=0] (Describes Star Wars: A New Hope).

TweetEval (Sentiment or Emotion subset adapted for Multi-Label)


Description: TweetEval benchmarks various NLP tasks on Tweets. We can adapt a subset, e.g.,
emotion recognition, to a multi-label format or define a custom multi-label task for social media
post classification (e.g., content type).

Labels (T ): {News, Opinion, Question, Personal Update, Promotion, Complaint, Recommendation} (T=7)

Privileged Labels (P): {Question, Promotion, Complaint} (Chosen because identifying these
specific intents can be crucial for engagement, moderation, or business intelligence, and they might
be less common than Opinions or Updates).

Example (xi , yi ):

• xi : ”Anyone know a good place for brunch near downtown? Preferably not too expensive!
#foodie #help”

• yi : [News=0, Opinion=0, Question=1, Personal Update=0, Promotion=0, Complaint=0, Recommendation=0] (Asking for a recommendation).

MultiNLI (Multi-Genre Natural Language Inference)


Description: A dataset for Natural Language Inference. Given a premise sentence and a hy-
pothesis sentence, the task is to predict whether the premise entails the hypothesis, contradicts it,
or neither (neutral). This is a 3-class classification problem.

Labels (T ): {Entailment, Neutral, Contradiction} (T=3)

Privileged Labels (P): {Contradiction} (Models often find it harder to reliably detect contra-
diction compared to entailment, making it a candidate for privileged status to improve performance
on this specific relation).

Example (xi , yi ):

• xi : Premise: ”The cat sat on the mat.” Hypothesis: ”The cat was sleeping under the table.”

• yi : [Entailment=0, Neutral=0, Contradiction=1]
