Add GPTBigCode model (Optimized GPT2 with MQA from Santacoder & BigCode) #22575
Conversation
@lvwerra @harm-devries
Code on the Hub is fine too and we are adding better support for it every day :-)
Hi @sgugger, the next generation of the model will also support this architecture, so there should also be significantly more usage. I also discussed this with @LysandreJik previously; what do you think?
The documentation is not available anymore as the PR was closed or merged. |
if position_ids is None:
    position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
    position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
Could benefit from #21853
I'm not sure I understand that PR. Is it to make generation independent of the padding? If so, we definitely want it.
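For reference, the idea (as I understand it) is to derive the position ids from the attention mask so that left padding does not shift positions during generation. A minimal sketch of that idea, not the exact code from #21853:

import torch

# Minimal sketch of padding-aware position ids (illustrative, not the PR's exact code).
# Positions count only the real tokens, so left padding does not shift them.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])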
If you prefer, @lvwerra, and if the architecture is frozen: we won't be able to accommodate changes after it's merged and released in Transformers (no breaking changes in Transformers), whereas it's easier to quickly experiment with code on the Hub. If you feel the model is mature enough and it's time, I'm not opposed :-)
zeros = torch.zeros(attn_view, dtype=query.dtype, device=query.device)
attn_weights = torch.baddbmm(zeros, query, key, beta=1, alpha=scale_factor).view(attn_shape)
For learning purposes:
Note that this block was previously:

attn_weights = torch.baddbmm(
    torch.empty(attn_view), query, key, beta=0, alpha=scale_factor
).view(attn_shape)

This seemed to be needed to fix the CI tests that were failing on CPU. The reason is the following: torch.empty(attn_view) creates a tensor of shape attn_view that also contains random values on the order of 1e-43. Even though that uninitialized tensor is multiplied by beta (which is hardcoded to 0), it led to overflows on CPU only, leaving nan values inside attn_weights. Hence the fix seemed to be to create a tensor of zeros and multiply it by an arbitrary float value (here 1).
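For clarity, here are the two variants side by side (shapes are illustrative, not the model's actual dimensions):

import torch

# Illustrative shapes: attn_view is (batch * heads, q_len, k_len) in the model.
attn_view, attn_shape = (4, 8, 8), (4, 1, 8, 8)
scale_factor = 0.125
query = torch.randn(4, 8, 16)
key = torch.randn(4, 16, 8)

# Old variant: uninitialized buffer, relying on beta=0 to ignore its contents.
old = torch.baddbmm(torch.empty(attn_view), query, key, beta=0, alpha=scale_factor).view(attn_shape)

# Fixed variant: explicit zeros with beta=1, which avoided the nan issue on CPU.
zeros = torch.zeros(attn_view, dtype=query.dtype, device=query.device)
new = torch.baddbmm(zeros, query, key, beta=1, alpha=scale_factor).view(attn_shape)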
@jlamypoirier mentioned that this would add some overhead on GPU, so I will add a check for whether the model is running on CPU.
👀 I knew it
Thanks a lot for adding this new model! My main comment is on the weight initialization.
*The BigCode project is an open-scientific collaboration working on the responsible development of large language models for code. This tech report describes the progress of the collaboration until December 2022, outlining the current state of the Personally Identifiable Information (PII) redaction pipeline, the experiments conducted to de-risk the model architecture, and the experiments investigating better preprocessing methods for the training data. We train 1.1B parameter models on the Java, JavaScript, and Python subsets of The Stack and evaluate them on the MultiPL-E text-to-code benchmark. We find that more aggressive filtering of near-duplicates can further boost performance and, surprisingly, that selecting files from repositories with 5+ GitHub stars deteriorates performance significantly. Our best model outperforms previous open-source multilingual code generation models (InCoder-6.7B and CodeGen-Multi-2.7B) in both left-to-right generation and infilling on the Java, JavaScript, and Python portions of MultiPL-E, despite being a substantially smaller model. All models are released under an OpenRAIL license at [this https URL.](https://huggingface.co/bigcode)*
The model is an optimized GPT2 model with support for Multi-Query Attention.
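A rough sketch of what Multi-Query Attention changes compared to standard multi-head attention (illustrative only, not the model's actual implementation): all query heads share a single key/value head, which shrinks the KV cache by a factor of the number of heads.

import torch

# Illustrative Multi-Query Attention: n_head query heads, one shared key/value head.
batch, seq, n_head, head_dim = 2, 10, 8, 32

query = torch.randn(batch, n_head, seq, head_dim)  # per-head queries
key = torch.randn(batch, 1, seq, head_dim)         # single shared key head
value = torch.randn(batch, 1, seq, head_dim)       # single shared value head

# Broadcasting over the head dimension stands in for per-head key/value projections.
scores = query @ key.transpose(-1, -2) / head_dim ** 0.5   # (batch, n_head, seq, seq)
attn_output = torch.softmax(scores, dim=-1) @ value        # (batch, n_head, seq, head_dim)
print(attn_output.shape)  # torch.Size([2, 8, 10, 32])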
Maybe link to the GPT-2 model doc page here?
# Copyright 2023 The OpenAI Team Authors.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
Not sure we need more than the BigCode team and Hugging Face here.
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
    if name == "c_proj.weight":
        # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
        p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
This is a bit ugly and won't necessarily work since the module is not marked as initialized, so it will get re-initialized as a linear layer afterwards. This should be in a check for GPTBigCodeAttention, where you initialize module.c_proj.weight this way and then mark module.c_proj._is_hf_initialized = True so that the layer is not reinitialized. See this example in OneFormer for instance.
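A sketch of the suggested pattern, assuming it lives in the model's _init_weights and that GPTBigCodeAttention and self.config come from this PR's modeling file (the exact final code may differ):

import math
import torch.nn as nn

def _init_weights(self, module):
    # Sketch only: handle the attention block explicitly so its output projection
    # gets the scaled initialization, then flag it as initialized so the generic
    # nn.Linear branch below does not overwrite it.
    if isinstance(module, GPTBigCodeAttention):
        module.c_proj.weight.data.normal_(
            mean=0.0,
            std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer),
        )
        module.c_proj.bias.data.zero_()
        module.c_proj._is_hf_initialized = True  # skip later re-initialization
    elif isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()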
This part is copied as-is from GPT2...
So let's fix GPT-2 too.
loss = None
if labels is not None:
    # Shift so that tokens < n predict n
Move the labels to the logits device here, for model parallelism support.
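A minimal sketch of what that looks like in the causal-LM loss (variable names are illustrative, not the exact diff):

import torch
from torch.nn import CrossEntropyLoss

# Illustrative tensors standing in for the model's outputs and inputs.
lm_logits = torch.randn(2, 10, 100)      # (batch, seq, vocab)
labels = torch.randint(0, 100, (2, 10))  # may live on a different device under model parallelism

# Move labels to the logits' device, then shift so that tokens < n predict n.
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))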
Again these are copies of GPT2, but I'm open to fixing them.
loss = None
if labels is not None:
    if self.config.problem_type is None:
Same here.
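For context, the problem_type inference used by the library's sequence-classification heads boils down to roughly this (a sketch, not this PR's exact code):

import torch

def infer_problem_type(labels: torch.Tensor, num_labels: int) -> str:
    # Sketch of the usual heuristic: one label -> regression, integer labels ->
    # single-label classification, otherwise multi-label classification.
    if num_labels == 1:
        return "regression"
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"
    return "multi_label_classification"

print(infer_problem_type(torch.tensor([0, 2, 1]), num_labels=3))  # single_label_classification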
loss = None
if labels is not None:
    loss_fct = CrossEntropyLoss()
Same here.
Thanks a lot for your feedback! I just addressed them all.
Please wait a bit before merging; I'll do a final check of the latest changes.
if query.device == torch.device("cpu"):
    # this seemed to be needed - on CPU only: check https://github.com/huggingface/transformers/pull/22575/files#r1159858870
    zeros = torch.zeros(attn_view, dtype=query.dtype, device=query.device)
    attn_weights = torch.baddbmm(zeros, query, key, beta=1, alpha=scale_factor).view(attn_shape)
I know it doesn't really matter, but shouldn't this also have beta=0?
I tried it with beta=0 and still got the issue... maybe torch complains if you multiply a zero tensor by 0.
That should be the least problematic operation here...
The point is that baddbmm shouldn't even read the input when beta=0. Also, do you have some code to reproduce this? On my machine there is no problem:
>>> a=torch.full([50, 50, 50], torch.nan)
>>> b=torch.randn([50, 50, 50])
>>> c=torch.randn([50, 50, 50])
>>> d=torch.full([50, 50, 50], 0)
>>> torch.baddbmm(a,b,c, beta=1, alpha=5).sum()
tensor(nan)
>>> torch.baddbmm(a,b,c, beta=0, alpha=5).sum()
tensor(-9915.6406)
>>> torch.baddbmm(d,b,c, beta=0, alpha=5).sum()
tensor(-9915.6406)
>>> torch.baddbmm(d,b,c, beta=1, alpha=5).sum()
tensor(-9915.6406)
>>> (5*torch.bmm(b,c)).sum()
tensor(-9915.6406)
Now running on an A100:
import torch
s=32
dtype=torch.float32
a=torch.full([s, s, s], torch.nan, dtype=dtype)
b=torch.randn([s, s, s], dtype=dtype)
c=torch.randn([s, s, s], dtype=dtype)
d=torch.zeros([s, s, s], dtype=dtype)
y0=torch.baddbmm(a,b,c, beta=1, alpha=5)
y1=torch.baddbmm(a,b,c, beta=0, alpha=5)
y2=torch.baddbmm(d,b,c, beta=0, alpha=5)
y3=torch.baddbmm(d,b,c, beta=1, alpha=5)
y4=torch.bmm(b,c)*5
y5=torch.matmul(b,c)*5
yy=[y0,y1,y2,y3,y4,y5]
aa=a.cuda()
bb=b.cuda()
cc=c.cuda()
dd=d.cuda()
z0=torch.baddbmm(aa,bb,cc, beta=1, alpha=5).cpu()
z1=torch.baddbmm(aa,bb,cc, beta=0, alpha=5).cpu()
z2=torch.baddbmm(dd,bb,cc, beta=0, alpha=5).cpu()
z3=torch.baddbmm(dd,bb,cc, beta=1, alpha=5).cpu()
z4=(torch.bmm(bb,cc)*5).cpu()
z5=(torch.matmul(bb,cc)*5).cpu()
zz=[z0,z1,z2,z3,z4,z5]
print([(z-y).std() for y,z in zip(yy,zz)])
print([(y-yy[1]).std() for y in yy])
print([(z-zz[1]).std() for z in zz])
>>> print([(z-y).std() for y,z in zip(yy,zz)])
[tensor(nan), tensor(0.0084), tensor(0.0084), tensor(0.0084), tensor(0.0084), tensor(0.0084)]
>>>
>>> print([(y-yy[1]).std() for y in yy])
[tensor(nan), tensor(0.), tensor(0.), tensor(0.), tensor(4.2393e-06), tensor(4.2393e-06)]
>>> print([(z-zz[1]).std() for z in zz])
[tensor(nan), tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.)]
>>>
From this output it looks like the output is indeed different on CPU vs GPU, but that happens for every single kind of matrix multiplication, so there is nothing we can do about it... In general I don't think we should expect numerically equal outputs on different devices.
I could reproduce it through a common test; what I did was to replace the following block with:
if False:
# this seemed to be needed - on CPU only: check https://github.com/huggingface/transformers/pull/22575/files#r1159858870
zeros = torch.zeros(attn_view, dtype=query.dtype, device=query.device)
attn_weights = torch.baddbmm(zeros, query, key, beta=1, alpha=scale_factor).view(attn_shape)
else:
# We do the standard operation on GPU for faster inference
attn_weights = torch.baddbmm(
torch.empty(attn_view, device=query.device, dtype=query.dtype), query, key, beta=0, alpha=scale_factor
).view(attn_shape)
if attn_weights.isnan().any():
print()
And run
CUDA_VISIBLE_DEVICES= pytest tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py::GPTBigCodeMQAModelTest::test_beam_sample_generate
and put a breakpoint right on the print. Strangely I couldn't reproduce in a small snippet. Here is what I have tried:
import torch
N_EXPERIMENTS = 10000
device = torch.device("cpu")
batch_size=8
q_len=4
k_len=5
hidden_dim=8
dtype=torch.float32
attn_view =(batch_size, q_len, k_len)
attn_shape = (batch_size, 1, q_len, k_len)
for _ in range(N_EXPERIMENTS):
query = torch.randn(batch_size, q_len, hidden_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, hidden_dim, k_len, device=device, dtype=dtype)
a = torch.empty(attn_view, device=query.device, dtype=query.dtype)
out = torch.baddbmm(a, query, key, beta=0, alpha=0.356).view(attn_shape)
if out.isnan().any():
print(out.isnan().any())
I did a bit more investigating; from what I could find, it's a semi-random bug that only manifests itself when it feels like it, so it's not easy to reproduce. I was able to get it consistently with:
import torch
s=7
dtype=torch.float32
aa=[]
n=[]
for i in range(10000):
a=torch.full([s,s,s],torch.nan, dtype=dtype)
b=torch.randn([s,s,s], dtype=dtype)
c=torch.randn([s,s,s], dtype=dtype)
y=torch.baddbmm(a,b,c, beta=0, alpha=5)
aa+=[a,b,c]
n.append(y.isnan().float().mean().item())
>>> print(torch.mean(torch.tensor(n)).item())
0.2447994202375412
For some reason I never get any nan with s>=8, and only rarely when not accumulating tensors in the list (i.e. reusing memory addresses).
Edit: interestingly, it can produce nans with beta=0 even if the input doesn't have any. It's way less common but seems enough to break the tests.
import torch
s=7
dtype=torch.float32
aa=[]
n=[]
for i in range(10000):
a=torch.zeros([s,s,s], dtype=dtype)
b=torch.randn([s,s,s], dtype=dtype)
c=torch.randn([s,s,s], dtype=dtype)
y=torch.baddbmm(a,b,c, beta=0, alpha=5)
aa+=[a,b,c]
n.append(y.isnan().float().mean().item())
>>> print(torch.mean(torch.tensor(n)).item())
1.1661808230201132e-06
Anyway, it looks like that's a known bug pytorch/pytorch#96037. It's been fixed in pytorch/pytorch#96086 but only for the next release of pytorch. So for now we should leave the zero, I'll just simplify that code and add a reference to the issue in the comment. (In the future it could be updated to only set to zero for torch version <=2.0.0)
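A sketch of what that future version gate could look like (assuming the upstream fix ships in releases after torch 2.0.0; the threshold is an assumption):

import torch
from packaging import version

# Assumption: the baddbmm fix (pytorch/pytorch#96086) lands after torch 2.0.0.
_TORCH_HAS_BADDBMM_FIX = version.parse(torch.__version__.split("+")[0]) > version.parse("2.0.0")

def attention_bias(attn_view, query):
    # Returns the bias tensor and beta to pass to torch.baddbmm.
    if _TORCH_HAS_BADDBMM_FIX:
        # Fixed torch: the uninitialized buffer is safe since beta=0 never reads it.
        return torch.empty(attn_view, dtype=query.dtype, device=query.device), 0
    # Work around pytorch/pytorch#96037: explicit zeros with beta=1.
    return torch.zeros(attn_view, dtype=query.dtype, device=query.device), 1

The call site would then do something like bias, beta = attention_bias(attn_view, query) and pass both into torch.baddbmm(bias, query, key, beta=beta, alpha=scale_factor).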
Thanks a mile for investigating!
    GPTBigCodeForCausalLM,
    GPTBigCodeForSequenceClassification,
    GPTBigCodeForTokenClassification,
    GPTBigCodeModel,
)
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention

torch.backends.cuda.matmul.allow_tf32 = False
Why??? If needed, that would need a comment, and it can't go in the imports.
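If it really were needed, one option would be to scope it to the test instead of setting it at import time, e.g. (illustrative):

import torch

def test_precision_sensitive_case():
    # Scope the TF32 toggle to the test and restore the previous value afterwards,
    # instead of flipping the global flag at import time.
    previous = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    try:
        pass  # run the precision-sensitive checks here
    finally:
        torch.backends.cuda.matmul.allow_tf32 = previous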
The tests pass on A100 without it, so I'll just remove it.
I did a few minor tweaks; I'm OK with merging if it works for everyone (assuming CI passes).
Looking great to me as well! Thanks a lot for all your work and investigations @jlamypoirier 🔥!
Thanks again for everything!
…de) (huggingface#22575)

* Add model with cli tool
* Remove unwanted stuff
* Add new code
* Remove inference runner
* Style
* Fix checks
* Test updates
* make fixup
* fix docs
* fix doc
* fix test
* hopefully fix pipeline tests
* refactor
* fix CIs
* add comment
* rename to `GPTBigCodeForCausalLM`
* correct readme
* make fixup + docs
* make fixup
* fixes
* fixes
* Remove pruning
* Remove import
* Doc updates
* More pruning removal
* Combine copies
* Single MQA implementation, remove kv cache pre-allocation and padding
* Update doc
* Revert refactor to match gpt2 style
* Merge back key and value caches, fix some type hints
* Update doc
* Fix position ids pith padding (PR 21080)
* Add conversion script temporarily
* Update conversion script
* Remove checkpoint conversion
* New model
* Fix MQA test
* Fix copies
* try fix tests
* FIX TEST!!
* remove `DoubleHeadsModel`
* add MQA tests
* add slow tests
* clean up
* add CPU checker
* final fixes
* fixes - fix GPU issue - fixed slow tests - skip disk offload
* fix final issue
* Simplify and comment baddbmm fix
* Remove unnecessary code
* Transpose tweaks
* Use beta=1 on cpu, improve tests

---------

Co-authored-by: younesbelkada <[email protected]>
Any updates on supporting flash attention? Or is there a different PR to track it?
cc @younesbelkada I think this is supported in BetterTransformer, no?
Indeed this should go into
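Presumably this refers to BetterTransformer support in `optimum` (an assumption on my part). Usage would look roughly like the sketch below; the checkpoint name is only illustrative, and whether `gpt_bigcode` is actually covered depends on the installed `optimum` version.

# Requires `pip install optimum`; checkpoint name is illustrative.
from transformers import AutoModelForCausalLM
from optimum.bettertransformer import BetterTransformer

model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder")
model = BetterTransformer.transform(model)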
The GPTBigCode model from BigCode. It is the same model as GPT2, with Multi-Query Attention (MQA).
Other than MQA, it's the same model as GPT2, just a new implementation (though it's not numerically equivalent and the checkpoints are not compatible).
The optimizations (I might be missing some):

- Use the `gelu_pytorch_tanh` activation (see "Add the pytorch implementation of the OpenAI GeLU approximation" #21344 and "Add the GeLU activation from pytorch with the tanh approximation" #21345)
- Merge `_attn` and `_upcast_and_reordered_attn`. Always merge the matmul with scaling. Rename `reorder_and_upcast_attn` -> `attention_softmax_in_fp32`
- Rename `scale_attn_by_inverse_layer_idx` -> `scale_attention_softmax_in_fp32` and change its behavior to match Megatron-LM (divide by layer_idx in fp16, then multiply in fp32).
- Changes to `layer_past`/`present` (does it risk creating problems?)

Excluded from this PR (optional/opt-in features, could be added later):

TODO: