3. Multi-class Classification
Yunwen Lei
1 Multiclass Predictors
3 Frank-Wolfe Algorithm
One-vs-One
Train c(c − 1)/2 binary classifiers, one for each pair of distinct classes
The final prediction is made by majority vote over the pairwise classifiers (see the sketch below)
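The pairwise scheme fits in a few lines. Below is a minimal sketch (mine, not from the lecture), assuming a helper `train_binary(X, y)` that fits any binary scorer and returns a real-valued score function; the least-squares scorer is included only to make it runnable.

```python
import numpy as np
from itertools import combinations

def train_binary(X, y):
    # Illustrative placeholder: a least-squares linear scorer.
    # Any binary learner returning a real-valued score function would do here.
    w, *_ = np.linalg.lstsq(X, y, rcond=None)
    return lambda Z: Z @ w

def one_vs_one_predict(X_train, y_train, X_test, classes):
    """Train c(c-1)/2 pairwise classifiers and predict by majority vote."""
    n_test, c = X_test.shape[0], len(classes)
    votes = np.zeros((n_test, c), dtype=int)
    for a, b in combinations(range(c), 2):
        mask = np.isin(y_train, [classes[a], classes[b]])
        yy = np.where(y_train[mask] == classes[a], 1.0, -1.0)  # +1 = class a, -1 = class b
        h = train_binary(X_train[mask], yy)
        is_a = h(X_test) >= 0
        votes[is_a, a] += 1
        votes[~is_a, b] += 1
    return np.asarray(classes)[votes.argmax(axis=1)]           # majority vote
```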
Reduction to Binary Classification
One-vs-Rest
Train c binary classifiers, one for each class
Train j-th classifier to distinguish class j from the rest
Suppose h_1, . . . , h_c : X → R are our binary classifiers
The final prediction is h(x) = arg max_{j∈[c]} h_j(x) (see the sketch below)
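As a minimal sketch (mine, not from the slides), assuming `scorers` is a list of already-trained binary score functions h_j : X → R, one per class:

```python
import numpy as np

def one_vs_rest_predict(X, scorers):
    """One-vs-Rest: h(x) = argmax_j h_j(x), where h_j separates class j from the rest."""
    scores = np.column_stack([h(X) for h in scorers])  # shape (n_samples, c)
    return scores.argmax(axis=1)                       # index of the most confident class
```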
Y = {−1, +1}, X = R^d
Linear classifier score function h(x) = w^⊤ x
Final prediction: sign(h(x))
Class 2 vs Rest
▶ Predicts everything to be “Not 2”
▶ It misclassifies 6 examples (all examples in class “2”)
▶ If it predicted some examples as “2”, it would get many more “Not 2” examples wrong
Score for class j is
h_j(x) = w_j^⊤ x = ∥w_j∥_2 ∥x∥_2 cos θ_j,
where θ_j is the angle between x and w_j.
One-vs-Rest: Class Boundaries
Idea: use one function h(x, y ) to give a compatibility score between input x and output y
Final prediction is the y ∈ Y that is “most compatible” with x
Example weight vectors (d = 2, c = 3):
w_1 = (0, 1)^⊤,  w_2 = (−1/√2, 1/√2)^⊤,  w_3 = (1/√2, 1/√2)^⊤
Prediction function
x = (x_1, x_2)^⊤ ← arg max_{j∈{1,2,3}} ⟨w_j, x⟩
How can we get this into the form x ← arg max_{y∈Y} ⟨w, Ψ(x, y)⟩?
The Multivector Construction
What if we stack the w_j's together into a single parameter
w = (w_1, w_2, w_3), and in general w = (w_1, . . . , w_c) ∈ R^{d×c} (see the sketch below).
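A small sketch (mine) of the multivector construction: Ψ(x, y) places x in the y-th block of a d·c-dimensional vector, so that ⟨w, Ψ(x, y)⟩ = ⟨w_y, x⟩ and the two prediction rules coincide. The weight values are the example vectors from the earlier slide, as reconstructed above.

```python
import numpy as np

d, c = 2, 3
W = np.array([[0.0,            1.0],
              [-1/np.sqrt(2),  1/np.sqrt(2)],
              [ 1/np.sqrt(2),  1/np.sqrt(2)]])   # rows w_1, w_2, w_3 from the example
w = W.reshape(-1)                                # the stacked multivector in R^{d*c}

def psi(x, y):
    """Multivector feature map: copy x into the y-th block, zeros elsewhere."""
    out = np.zeros(d * c)
    out[y * d:(y + 1) * d] = x
    return out

x = np.array([0.3, -0.8])
pred_joint  = max(range(c), key=lambda y: w @ psi(x, y))  # argmax_y <w, Psi(x, y)>
pred_direct = int(np.argmax(W @ x))                       # argmax_j <w_j, x>
assert pred_joint == pred_direct                          # the two rules agree
```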
Recap: we choose the class label with the largest score
The prediction is correct if the true class label receives the largest score!
Multi-class Margin
Aim: want h(x_i, y_i) larger than all other h(x_i, j) for j ≠ y_i, i.e., a large margin
margin(h, z) = h(x, y) − max_{j: j≠y} h(x, j),
i.e., the score of the true label minus the largest score of an incorrect label.
Multi-class Margin
Defined as the difference between the score of the correct label and the highest score
among the incorrect labels (see the sketch below)
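A one-line computation of this quantity, as a sketch (mine, not from the slides):

```python
import numpy as np

def multiclass_margin(scores, y):
    """Margin: score of the true label minus the best score among the other labels."""
    scores = np.asarray(scores, dtype=float)
    return scores[y] - np.delete(scores, y).max()

print(multiclass_margin([1.2, 0.3, -0.5], y=0))  # 0.9: correct prediction with a margin
print(multiclass_margin([1.2, 2.0, -0.5], y=0))  # -0.8: negative margin = misclassified
```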
Margin-based Loss
The loss takes the form ℓ_y(t_1, . . . , t_c), where ℓ_y : R^c → R acts on the score vector t
margin: t_y − max_{j: j≠y} t_j
Soft-Max Loss
Note that the margin of t is t_y − max_{j≠y} t_j
The max is not differentiable; replace it with the soft-max:
max{t_1 − t_y, . . . , t_c − t_y} ≤ log Σ_{j∈[c]} exp(t_j − t_y) ≤ max{t_1 − t_y, . . . , t_c − t_y} + log c,
where the two outer maxima play the role of the negated margin and the middle term is the soft-max (a numerical check is sketched below).
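A quick numerical check of the sandwich inequality (my own sketch; the max-shift used to evaluate the log-sum-exp stably is an implementation detail, not part of the slides):

```python
import numpy as np

def neg_margin_term(t, y):
    return np.max(t - t[y])                    # max_j (t_j - t_y)

def soft_max_term(t, y):
    s = t - t[y]
    m = s.max()
    return m + np.log(np.exp(s - m).sum())     # log sum_j exp(t_j - t_y), computed stably

rng = np.random.default_rng(0)
c, y = 5, 2
for _ in range(1000):
    t = rng.normal(scale=3.0, size=c)
    lo = neg_margin_term(t, y)
    mid = soft_max_term(t, y)
    hi = lo + np.log(c)
    assert lo <= mid <= hi + 1e-12             # max <= soft-max <= max + log c
```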
Examples: Ω(w) = (λ/2) ∥w∥²_{2,p}, where
∥w∥_{2,p} := (Σ_{j=1}^c ∥w_j∥_2^p)^{1/p},  p ≥ 1
p = 1: ∥w∥_{2,1} := Σ_{j=1}^c ∥w_j∥_2
p = 2: ∥w∥_{2,2} := (Σ_{j=1}^c ∥w_j∥_2²)^{1/2}
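A sketch (mine) of the group norm ∥w∥_{2,p}, with w stored as a c × d matrix whose j-th row is w_j:

```python
import numpy as np

def group_norm(W, p):
    """||w||_{2,p} = ( sum_j ||w_j||_2^p )^{1/p}; the rows of W are the class vectors w_j."""
    row_norms = np.linalg.norm(W, axis=1)      # ||w_j||_2 for j = 1, ..., c
    return np.sum(row_norms ** p) ** (1.0 / p)

W = np.arange(6, dtype=float).reshape(3, 2)    # c = 3 classes, d = 2 features
print(group_norm(W, p=1))                      # ||w||_{2,1}: sum of the row norms
print(group_norm(W, p=2))                      # ||w||_{2,2}: equals the Frobenius norm of W
```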
The difficulty lies in the nonlinearity of ℓ_y, which takes a vector in R^c as input!
Lipschitz Continuity of ℓy
Example
The hinge loss ℓ_y(t_1, . . . , t_c) = max{0, 1 + max_{j: j≠y}(t_j − t_y)} is 2-Lipschitz continuous
w.r.t. ∥ · ∥_∞.
Example
The soft-max loss ℓ_y(t_1, . . . , t_c) = log(1 + Σ_{j: j≠y} exp(t_j − t_y)) is 2-Lipschitz continuous
w.r.t. ∥ · ∥_∞.
Proof. The function t ↦ g(t) := log Σ_{j∈[c]} exp(t_j) is 1-Lipschitz continuous w.r.t. ∥ · ∥_∞:
∇g(t) = (exp(t_1), . . . , exp(t_c))^⊤ / Σ_{j∈[c]} exp(t_j)  ⟹  ∥∇g(t)∥_1 ≤ 1,
and by a first-order Taylor expansion (mean value theorem), for some α ∈ [0, 1],
g(t) − g(t′) = ⟨t − t′, ∇g(αt + (1 − α)t′)⟩.
Vector-contraction inequality (Maurer, 2016): if the maps t ↦ ℓ_{y_i}(t) are L-Lipschitz w.r.t. ∥ · ∥_2, then
E sup_{f∈F} Σ_{i∈[n]} ϵ_i ℓ_{y_i}(f(z_i)) ≤ √2 L · E sup_{f∈F} Σ_{i∈[n]} Σ_{k∈[c]} ϵ_{i,k} f_k(z_i),
where ϵ_{i,k} are i.i.d. Rademacher variables, and f_k(z_i) is the k-th component of f(z_i).
Maurer, A., 2016. A vector-contraction inequality for Rademacher complexities. In Algorithmic Learning
Theory.
Rademacher Complexity Bound (Optional)
Let ℓy be either the hinge loss or the soft-max loss
|ℓ_y(t) − ℓ_y(t′)| ≤ 2 max_{j∈[c]} |t_j − t_j′| = 2∥t − t′∥_∞ ≤ 2 (Σ_{j=1}^c (t_j − t_j′)²)^{1/2} = 2∥t − t′∥_2.
By the concavity of x ↦ √x, we know
E[Σ_{j=1}^c ∥Σ_{i∈[n]} ϵ_{ij} x_i∥_2] ≤ Σ_{j=1}^c (E∥Σ_{i∈[n]} ϵ_{ij} x_i∥_2²)^{1/2}.
Since E∥Σ_{i∈[n]} ϵ_{ij} x_i∥_2² = Σ_{i∈[n]} E∥x_i∥_2² (the Rademacher cross terms vanish), this gives
⟹ E[Σ_{j=1}^c ∥Σ_{i∈[n]} ϵ_{ij} x_i∥_2] ≤ c · (Σ_{i∈[n]} E∥x_i∥_2²)^{1/2}.
Rademacher Complexity Bound
Generalization bound
Let A be ERM. Then with probability at least 1 − δ
F(A(S)) − F(w*) ≲ n^{−1/2} log^{1/2}(2/δ) + √c/√n.
In some cases, the projection may be hard to compute (or even to approximate)
The Frank-Wolfe method avoids projections: it optimizes a linear approximation of the objective over the constraint set
This requires the set W to be bounded, otherwise the linear subproblem may have no solution
Frank-Wolfe Method
1: for t = 0, 1, . . . , T do
2: Set vt = arg minv∈W ⟨∇FS (wt ), v⟩
3: Set wt+1 = wt + γt (vt − wt ), where γt ∈ [0, 1]
Note that wt+1 = (1 − γt)wt + γt vt is a convex combination, so wt+1 stays in W
A choice of γt is γt = 2/(t + 2) (a runnable sketch follows below)
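A minimal runnable version of the loop above (my own sketch). It assumes the caller supplies `grad(w)`, the gradient of FS at w, and `lmo(g)`, the linear minimization oracle introduced next.

```python
import numpy as np

def frank_wolfe(grad, lmo, w0, T):
    """Frank-Wolfe: linearize, call the LMO, and move by a convex combination."""
    w = np.array(w0, dtype=float)
    for t in range(T):
        v = lmo(grad(w))                  # v_t = argmin_{v in W} <grad F_S(w_t), v>
        gamma = 2.0 / (t + 2)             # the step size suggested above
        w = (1 - gamma) * w + gamma * v   # w_{t+1} stays inside the convex set W
    return w
```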
Linear Minimization Oracle
Let W be a convex, closed and bounded set. The linear minimization oracle of W
(LMO_W) takes a vector g and returns a vector ŵ ∈ W such that ⟨ŵ, g⟩ = min_{v∈W} ⟨v, g⟩.
Lasso Regression
min_w FS(w) = ∥Aw − b∥_2²  s.t.  ∥w∥_1 ≤ 1.
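For the ℓ1 ball, the linear minimization oracle is a signed coordinate vector, so each Frank-Wolfe step touches a single coordinate. Below is a self-contained sketch (mine; the data A, b are synthetic):

```python
import numpy as np

def lmo_l1_ball(g, radius=1.0):
    """argmin_{||v||_1 <= radius} <g, v>: all mass on the largest |g_i|, with opposite sign."""
    i = int(np.argmax(np.abs(g)))
    v = np.zeros_like(g)
    v[i] = -radius * np.sign(g[i])
    return v

rng = np.random.default_rng(0)
A = rng.normal(size=(50, 20))
w_true = np.zeros(20)
w_true[:3] = [0.5, -0.3, 0.2]               # sparse target with ||w_true||_1 = 1
b = A @ w_true

w = np.zeros(20)
for t in range(1000):
    g = 2 * A.T @ (A @ w - b)               # gradient of F_S(w) = ||Aw - b||_2^2
    v = lmo_l1_ball(g)                      # linear subproblem over {||w||_1 <= 1}
    gamma = 2.0 / (t + 2)
    w = (1 - gamma) * w + gamma * v

print(np.linalg.norm(A @ w - b) ** 2)       # objective value after the FW iterations
```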
FS(wt) − FS(w*) ≤ 2LD²/(t + 1),  where D := max_{w,w′∈W} ∥w − w′∥_2.
By the L-smoothness of FS,
FS(wt+1) ≤ FS(wt) + ⟨wt+1 − wt, ∇FS(wt)⟩ + (L/2)∥wt+1 − wt∥_2²
= FS(wt) + γt ⟨vt − wt, ∇FS(wt)⟩ + (γt² L/2)∥vt − wt∥_2²
≤ FS(wt) + γt (FS(w*) − FS(wt)) + γt² LD²/2,
where the last step uses ⟨vt − wt, ∇FS(wt)⟩ ≤ ⟨w* − wt, ∇FS(wt)⟩ ≤ FS(w*) − FS(wt) (definition of vt and convexity of FS) together with ∥vt − wt∥_2 ≤ D.
⟹ FS(wt+1) − FS(w*) ≤ (1 − γt)(FS(wt) − FS(w*)) + γt² LD²/2.
Convergence of Frank-Wolfe Method
We derived FS(wt+1) − FS(w*) ≤ (1 − γt)(FS(wt) − FS(w*)) + γt² LD²/2.
With γt = 2/(t + 2), this becomes
⟹ FS(wt+1) − FS(w*) ≤ (t/(t + 2))(FS(wt) − FS(w*)) + 4LD²/(2(t + 2)²).
Telescoping shows, with Δ_t := FS(wt) − FS(w*),
(t + 1)(t + 2)Δ_{t+1} = Σ_{k=0}^t [(k + 1)(k + 2)Δ_{k+1} − k(k + 1)Δ_k] ≤ Σ_{k=0}^t 2LD² = 2LD²(t + 1).
Stochastic Frank-Wolfe Method
1: for t = 0, 1, . . . , T do
2: Set vt = arg min_{v∈W} ⟨∇̂_t, v⟩, where ∇̂_t is an unbiased estimator of ∇FS(wt)
3: Set wt+1 = wt + γt (vt − wt ), where γt ∈ [0, 1]
Convergence Rate
Let W be nonempty, convex and bounded. Let FS be convex and L-smooth. Assume the
following variance condition
E∥∇̂_t − ∇FS(wt)∥_2² ≤ (LD/(t + 1))².
Then FW with γt = 2/(t + 2) satisfies
E[FS(wt) − FS(w*)] ≤ 4LD²/(t + 1).
A mini-batch estimator:
∇̂_t = (1/|St|) Σ_{i∈St} ∇f(wt; zi).
It is unbiased:
E_{St}[∇̂_t] = (1/|St|) E_{St}[Σ_{i∈St} ∇f(wt; zi)] = ∇FS(wt).
Its variance is
E∥∇̂_t − ∇FS(wt)∥_2² = σ²/|St|,  σ² := E[∥∇f(wt; z_{it}) − ∇FS(wt)∥_2²].
Choosing |St| = σ²(t + 1)²/(L²D²) meets the variance condition for stochastic FW.
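A sketch (mine) of this mini-batch estimator with the growing batch size; `grad_i(w, i)` is an assumed helper returning ∇f(w; z_i), and σ², L, D are treated as known constants.

```python
import numpy as np

def minibatch_gradient(grad_i, n, w, t, sigma2, L, D, rng):
    """Average per-example gradients over a batch of size ~ sigma^2 (t+1)^2 / (L^2 D^2),
    so that the variance sigma^2 / |S_t| stays below (L D / (t+1))^2."""
    batch_size = max(1, int(np.ceil(sigma2 * (t + 1) ** 2 / (L ** 2 * D ** 2))))
    S_t = rng.integers(0, n, size=batch_size)          # indices sampled uniformly from [n]
    return np.mean([grad_i(w, i) for i in S_t], axis=0)
```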
Optimization for Multi-class Classification
Training Multi-class Model by Frank-Wolfe Method
Recall the problem
min_w (1/n) Σ_{i=1}^n ℓ_{y_i}((w_j^⊤ x_i)_{j∈[c]})  s.t.  Σ_{j∈[c]} ∥w_j∥_2² ≤ B².
The linear subproblem min_w ⟨w, g⟩ over this constraint set (with B = 1) has a closed-form solution w* = (w_1*, . . . , w_c*) as follows:
w_j* = − g_j / (Σ_{j̃=1}^c ∥g_{j̃}∥_2²)^{1/2}.   (2)
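A sketch (mine) of the closed-form oracle (2), with the gradient stored as a c × d matrix G whose j-th row is g_j; the radius argument B is an assumed generalization of the B = 1 case stated in (2).

```python
import numpy as np

def lmo_block_l2(G, B=1.0):
    """w*_j = -B * g_j / ( sum_j ||g_j||_2^2 )^{1/2}, i.e. -B * G / ||G||_{2,2}."""
    denom = np.sqrt(np.sum(np.linalg.norm(G, axis=1) ** 2))   # = ||G||_{2,2} (Frobenius norm)
    return -B * G / denom

G = np.array([[3.0, 4.0], [0.0, 0.0], [0.0, 5.0]])
W_star = lmo_block_l2(G)
print(np.sum(np.linalg.norm(W_star, axis=1) ** 2))   # 1.0: w* lies on the constraint boundary
print(np.sum(W_star * G))                            # -||G||_{2,2}: the minimal linear value
```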
Proof on Closed-form Solution (Optional)
w* = arg min_{w: Σ_{j∈[c]} ∥w_j∥_2² ≤ 1} ⟨w, g⟩  ⟺  w_j* = − g_j / (Σ_{j̃=1}^c ∥g_{j̃}∥_2²)^{1/2},
and
⟨w*, g⟩ = Σ_{j=1}^c ⟨w_j*, g_j⟩ = − (Σ_{j̃=1}^c ∥g_{j̃}∥_2²)^{−1/2} Σ_{j=1}^c ∥g_j∥_2² = − (Σ_{j̃=1}^c ∥g_{j̃}∥_2²)^{1/2} = −∥g∥_{2,2}.
Summary