Add GPTBigCode model (Optimized GPT2 with MQA from Santacoder & BigCode) #22575


Merged
54 commits merged on Apr 10, 2023

Conversation

jlamypoirier
Contributor

@jlamypoirier jlamypoirier commented Apr 4, 2023

The GPTBigCode model from BigCode. It is the same model as GPT2, with multi-query attention (MQA); a minimal sketch of MQA is shown right after the optimization list below.

Other than MQA, it's the same model as GPT2, just with a new implementation (though it's not numerically equivalent, and the checkpoints are not compatible).

The optimizations (I might be missing some):

  • Use gelu_pytorch_tanh (see #21344 "Add the pytorch implementation of the OpenAI GeLU approximation" and #21345 "Add the GeLU activation from pytorch with the tanh approximation").
  • Avoid unnecessary synchronizations (added to GPT2 in #20061 "Change constant torch.tensor to torch.full", but wasn't in the original Santacoder).
  • Use Linear layers instead of Conv1D (good speedup but makes the checkpoints incompatible).
  • Merge _attn and _upcast_and_reordered_attn. Always merge the matmul with scaling. Rename reorder_and_upcast_attn -> attention_softmax_in_fp32.
  • Rename scale_attn_by_inverse_layer_idx -> scale_attention_softmax_in_fp32 and change its behavior to match Megatron-LM (divide by layer_idx in fp16, then multiply in fp32).
  • Cache the attention mask value to avoid recreating it every time.
  • Use jit to fuse the attention fp32 casting, masking, softmax, and scaling.
  • Combine the attention and causal masks into a single one, pre-computed for the whole model instead of every layer.
  • Merge the key and value caches into one (this changes the format of layer_past/present; does it risk creating problems?)
  • Use the memory layout (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim) for the QKV tensor with MHA. (prevents an overhead with the merged key and values, but makes the checkpoints incompatible).
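
For readers unfamiliar with MQA, here is a minimal, self-contained sketch of multi-query attention with a merged key/value cache. It is illustrative only: the shapes, the mqa_attention name, and the einsum-based implementation are assumptions for clarity, not the actual modeling code in this PR.

import torch


def mqa_attention(query, key_value, head_dim, layer_past=None):
    # query:     (batch, q_len, num_heads * head_dim)
    # key_value: (batch, kv_len, 2 * head_dim) -- a single K/V head shared by all query heads
    if layer_past is not None:
        # Merged KV cache: a single tensor instead of a (key, value) tuple.
        key_value = torch.cat((layer_past, key_value), dim=1)
    present = key_value

    key, value = key_value.split(head_dim, dim=-1)  # (batch, kv_len, head_dim) each

    batch, q_len, _ = query.shape
    query = query.view(batch, q_len, -1, head_dim)  # (batch, q_len, num_heads, head_dim)

    # All query heads attend to the same key/value head.
    scores = torch.einsum("bqhd,bkd->bhqk", query, key) / head_dim**0.5
    probs = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhqk,bkd->bqhd", probs, value)
    return out.reshape(batch, q_len, -1), present

Because every query head shares one key/value head, the KV cache (and the per-step memory traffic during generation) shrinks by a factor of num_heads relative to standard multi-head attention, which is where most of the inference speedup comes from.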

Excluded from this PR (optional/opt-in features, could be added later):

  • CPU optimization for inference, aka InferenceRunner (huge speedup for generation with pre-allocated tensors, pre-computed views and support; faster than DeepSpeed, but too experimental to add now)
  • KV cache pre-allocation and padding (same reason)
  • MQA with separate Q and KV (MQA2 in BigCode; a bit faster for training, slower for inference)
  • FlashAttention (planning to add support in the near future)
  • Conversion script for Megatron weights (the MQA part needs the BigCode fork of Megatron)

TODO:

  • Update/fix the tests
  • Update the docs (should be mostly ok by now)
  • Address the remaining circleci issues (mostly related to the tests)

@jlamypoirier
Contributor Author

@lvwerra @harm-devries
(Replaces #21253)

@jlamypoirier jlamypoirier changed the title Add GPTBigCode model Add GPTBigCode model (Optimized GPT2 with MQA from Santacoder & BigCode models) Apr 4, 2023
@jlamypoirier jlamypoirier changed the title Add GPTBigCode model (Optimized GPT2 with MQA from Santacoder & BigCode models) Add GPTBigCode model (Optimized GPT2 with MQA from Santacoder & BigCode) Apr 4, 2023
@sgugger
Collaborator

sgugger commented Apr 4, 2023

Code on the Hub is fine too and we are adding better support for it every day :-)

@lvwerra
Member

lvwerra commented Apr 5, 2023

Hi @sgugger, the next generation of the model will also support this architecture, so there should also be significantly more usage. I also discussed this with @LysandreJik previously; what do you think?

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 5, 2023

The documentation is not available anymore as the PR was closed or merged.

Comment on lines 730 to 732
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
Collaborator

Could benefit from #21853

Contributor Author

I'm not sure I understand that PR; is it to make generation independent of the padding? If so, we definitely want it.
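
For context, the padding-independent approach typically derives the position ids from the attention mask rather than from past_length alone. A hedged sketch (variable names are assumed from the surrounding modeling code, not taken from that PR):

if position_ids is None:
    # Hypothetical padding-aware alternative: build positions from the attention mask
    # so left-padded sequences still start counting at 0.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    if past_length > 0:
        # During generation with a cache, only keep the positions of the new tokens.
        position_ids = position_ids[:, past_length:]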

@sgugger
Collaborator

sgugger commented Apr 5, 2023

If you prefer, @lvwerra, and if the architecture is frozen: we won't be able to accommodate changes after it's merged and released in Transformers (no breaking changes in Transformers), whereas it's easier to quickly experiment with code on the Hub. If you feel the model is mature enough and it's time, I'm not opposed :-)

Comment on lines 164 to 165
zeros = torch.zeros(attn_view, dtype=query.dtype, device=query.device)
attn_weights = torch.baddbmm(zeros, query, key, beta=1, alpha=scale_factor).view(attn_shape)
Contributor

For learning purposes:

Note that this block was previously:

attn_weights = torch.baddbmm(
   torch.empty(attn_view), query, key, beta=0, alpha=scale_factor
).view(attn_shape)

This seemed to be needed to fix the CI tests that were failing on CPU. The reason is the following: torch.empty(attn_view) creates a tensor of shape attn_view that contains uninitialized values on the order of magnitude of 1e-43. Even though that tensor is multiplied by beta (which is hardcoded to 0), it led to overflows on CPU only, leading to the presence of nan values inside attn_weights. Hence the fix seemed to be to create a tensor of zeros instead and multiply it by an arbitrary float value (here beta=1).
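
A minimal standalone sketch of the two variants, with illustrative shapes (not the modeling code itself):

import torch

batch, q_len, k_len, head_dim = 8, 4, 5, 16
query = torch.randn(batch, q_len, head_dim)
key = torch.randn(batch, head_dim, k_len)
scale_factor = head_dim ** -0.5

# Previous form: an uninitialized baseline, relying on beta=0 to ignore its contents.
attn_old = torch.baddbmm(torch.empty(batch, q_len, k_len), query, key, beta=0, alpha=scale_factor)

# Fix: an explicit zero baseline with beta=1, so uninitialized memory is never involved.
attn_new = torch.baddbmm(torch.zeros(batch, q_len, k_len), query, key, beta=1, alpha=scale_factor)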

Contributor

@jlamypoirier mentioned that this would add some overhead on GPU, so I will add a check for whether the model is running on CPU or not.

Collaborator

👀 I knew it

@younesbelkada younesbelkada marked this pull request as ready for review April 6, 2023 14:37
@younesbelkada younesbelkada requested a review from sgugger April 6, 2023 14:43
@sgugger sgugger (Collaborator) left a comment

Thanks a lot for adding this new model! My main comment is on the weight initialization.


*The BigCode project is an open-scientific collaboration working on the responsible development of large language models for code. This tech report describes the progress of the collaboration until December 2022, outlining the current state of the Personally Identifiable Information (PII) redaction pipeline, the experiments conducted to de-risk the model architecture, and the experiments investigating better preprocessing methods for the training data. We train 1.1B parameter models on the Java, JavaScript, and Python subsets of The Stack and evaluate them on the MultiPL-E text-to-code benchmark. We find that more aggressive filtering of near-duplicates can further boost performance and, surprisingly, that selecting files from repositories with 5+ GitHub stars deteriorates performance significantly. Our best model outperforms previous open-source multilingual code generation models (InCoder-6.7B and CodeGen-Multi-2.7B) in both left-to-right generation and infilling on the Java, JavaScript, and Python portions of MultiPL-E, despite being a substantially smaller model. All models are released under an OpenRAIL license at [this https URL.](https://huggingface.co/bigcode)*

The model is an optimized GPT2 model with support for Multi-Query Attention.
Collaborator

Maybe link to the GPT-2 model doc page here?

Comment on lines 3 to 4
# Copyright 2023 The OpenAI Team Authors.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
Collaborator

Not sure we need more than the BigCode team and Hugging Face here.

Comment on lines 397 to 406
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name == "c_proj.weight":
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
Collaborator

This is a bit ugly and won't necessarily work, since the module is not marked as initialized, so it will get re-initialized as a linear layer afterwards. This should be in a check for GPTBigCodeAttention, where you initialize module.c_proj.weight this way and then mark module.c_proj._is_hf_initialized = True so that the layer is not reinitialized. See this example in OneFormer for instance.
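
A hedged sketch of that suggestion inside _init_weights (it assumes import math, from torch import nn, and that GPTBigCodeAttention is the attention class in this file; it follows the OneFormer example rather than the final code of this PR):

def _init_weights(self, module):
    if isinstance(module, GPTBigCodeAttention):
        # Scaled init for the residual projection (1/sqrt(2 * n_layer)), then mark it
        # so the generic nn.Linear branch below does not re-initialize it.
        module.c_proj.weight.data.normal_(
            mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
        )
        if module.c_proj.bias is not None:
            module.c_proj.bias.data.zero_()
        module.c_proj._is_hf_initialized = True
    elif isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()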

Contributor Author

This part is copied as-is from GPT2...

Collaborator

So let's fix GPT-2 too.


loss = None
if labels is not None:
# Shift so that tokens < n predict n
Collaborator

Move the labels to the logits device here, for model parallelism support.
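
For example, a short sketch of the suggested change (variable names such as lm_logits are assumed from the GPT-2 style heads):

if labels is not None:
    # Move labels to the logits device so model-parallel placements work.
    labels = labels.to(lm_logits.device)
    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()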

Contributor Author

Again, these are copies of GPT2, but I'm open to fixing them.


loss = None
if labels is not None:
if self.config.problem_type is None:
Collaborator

Same here.


loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
Collaborator

Same here.

- fix GPU issue
- fixed slow tests
- skip disk offload
@younesbelkada younesbelkada requested a review from sgugger April 6, 2023 17:36
@younesbelkada
Contributor

younesbelkada commented Apr 6, 2023

Thanks a lot for your feedback! I just addressed it all.
Small note: the CPU/disk offload tests seem to not work in the testing suite, but I think it is related to the corner-case issues we faced with tiny T5 models, as the test passes for GPTBigCodeModelTest but does not pass for GPTBigCodeMQAModelTest.
I will also make sure the doctests pass before merging.

@jlamypoirier
Contributor Author

Please wait a bit before merging; I'll do a final check on the latest changes.

if query.device == torch.device("cpu"):
# this seemed to be needed - on CPU only: check https://github.com/huggingface/transformers/pull/22575/files#r1159858870
zeros = torch.zeros(attn_view, dtype=query.dtype, device=query.device)
attn_weights = torch.baddbmm(zeros, query, key, beta=1, alpha=scale_factor).view(attn_shape)
Contributor Author

I know it doesn't really matter, but shouldn't this also have beta=0?

Contributor

I tried it with beta=0 and still got the issue... maybe torch complains if you multiply a zero tensor by 0.

Contributor Author

That should be the least problematic operation here...
The point is that baddbmm shouldn't even read the input when beta=0. Also, do you have some code to reproduce this? On my machine there is no problem:

>>> a=torch.full([50, 50, 50], torch.nan)
>>> b=torch.randn([50, 50, 50])
>>> c=torch.randn([50, 50, 50])
>>> d=torch.full([50, 50, 50], 0)
>>> torch.baddbmm(a,b,c, beta=1, alpha=5).sum()
tensor(nan)
>>> torch.baddbmm(a,b,c, beta=0, alpha=5).sum()
tensor(-9915.6406)
>>> torch.baddbmm(d,b,c, beta=0, alpha=5).sum()
tensor(-9915.6406)
>>> torch.baddbmm(d,b,c, beta=1, alpha=5).sum()
tensor(-9915.6406)
>>> (5*torch.bmm(b,c)).sum()
tensor(-9915.6406)

Contributor Author

Now running on an A100:


import torch

s=32
dtype=torch.float32


a=torch.full([s, s, s], torch.nan, dtype=dtype)
b=torch.randn([s, s, s], dtype=dtype)
c=torch.randn([s, s, s], dtype=dtype)
d=torch.zeros([s, s, s], dtype=dtype)

y0=torch.baddbmm(a,b,c, beta=1, alpha=5)
y1=torch.baddbmm(a,b,c, beta=0, alpha=5)
y2=torch.baddbmm(d,b,c, beta=0, alpha=5)
y3=torch.baddbmm(d,b,c, beta=1, alpha=5)
y4=torch.bmm(b,c)*5
y5=torch.matmul(b,c)*5
yy=[y0,y1,y2,y3,y4,y5]

aa=a.cuda()
bb=b.cuda()
cc=c.cuda()
dd=d.cuda()

z0=torch.baddbmm(aa,bb,cc, beta=1, alpha=5).cpu()
z1=torch.baddbmm(aa,bb,cc, beta=0, alpha=5).cpu()
z2=torch.baddbmm(dd,bb,cc, beta=0, alpha=5).cpu()
z3=torch.baddbmm(dd,bb,cc, beta=1, alpha=5).cpu()
z4=(torch.bmm(bb,cc)*5).cpu()
z5=(torch.matmul(bb,cc)*5).cpu()
zz=[z0,z1,z2,z3,z4,z5]

print([(z-y).std() for y,z in zip(yy,zz)])

print([(y-yy[1]).std() for y in yy])
print([(z-zz[1]).std() for z in zz])
>>> print([(z-y).std() for y,z in zip(yy,zz)])
[tensor(nan), tensor(0.0084), tensor(0.0084), tensor(0.0084), tensor(0.0084), tensor(0.0084)]
>>> 
>>> print([(y-yy[1]).std() for y in yy])
[tensor(nan), tensor(0.), tensor(0.), tensor(0.), tensor(4.2393e-06), tensor(4.2393e-06)]
>>> print([(z-zz[1]).std() for z in zz])
[tensor(nan), tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.)]
>>>  

From this output it looks like the results are indeed different on CPU vs. GPU, but that happens for every single kind of matrix multiplication, so there is nothing we can do about it... In general, I don't think we should expect numerically equal outputs on different devices.

younesbelkada (Contributor) commented Apr 7, 2023

I could reproduce it through a common test; what I did was replace the following block with:

if False:
    # this seemed to be needed - on CPU only: check https://github.com/huggingface/transformers/pull/22575/files#r1159858870
    zeros = torch.zeros(attn_view, dtype=query.dtype, device=query.device)
    attn_weights = torch.baddbmm(zeros, query, key, beta=1, alpha=scale_factor).view(attn_shape)
else:
    # We do the standard operation on GPU for faster inference
    attn_weights = torch.baddbmm(
        torch.empty(attn_view, device=query.device, dtype=query.dtype), query, key, beta=0, alpha=scale_factor
    ).view(attn_shape)

    if attn_weights.isnan().any():
        print()

And run

CUDA_VISIBLE_DEVICES= pytest tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py::GPTBigCodeMQAModelTest::test_beam_sample_generate

and put a breakpoint right on the print. Strangely, I couldn't reproduce it in a small snippet. Here is what I tried:

import torch

N_EXPERIMENTS = 10000
device = torch.device("cpu")

batch_size=8
q_len=4
k_len=5
hidden_dim=8
dtype=torch.float32

attn_view =(batch_size, q_len, k_len)
attn_shape = (batch_size, 1, q_len, k_len)

for _ in range(N_EXPERIMENTS):
    query = torch.randn(batch_size, q_len, hidden_dim, device=device, dtype=dtype)
    key = torch.randn(batch_size, hidden_dim, k_len, device=device, dtype=dtype)

    a = torch.empty(attn_view, device=query.device, dtype=query.dtype)
    out = torch.baddbmm(a, query, key, beta=0, alpha=0.356).view(attn_shape)

    if out.isnan().any():
        print(out.isnan().any())

jlamypoirier (Contributor Author) commented Apr 7, 2023

I did a bit more investigating. From what I could find, it's a semi-random bug that only manifests itself when it feels like it, so it's not easy to reproduce. I was able to get it consistently with:

import torch
s=7
dtype=torch.float32
aa=[]

n=[]
for i in range(10000):
    a=torch.full([s,s,s],torch.nan, dtype=dtype)
    b=torch.randn([s,s,s], dtype=dtype)
    c=torch.randn([s,s,s], dtype=dtype)
    y=torch.baddbmm(a,b,c, beta=0, alpha=5)
    aa+=[a,b,c]
    n.append(y.isnan().float().mean().item())

>>> print(torch.mean(torch.tensor(n)).item())
0.2447994202375412

For some reason I never get any nan with s>=8, and rarely when not accumulating tensors in the list (i.e., reusing memory addresses).

Edit: interestingly, it can produce nans with beta=0 even if the input doesn't have any. It's way less common, but it seems to be enough to break the tests.

import torch
s=7
dtype=torch.float32
aa=[]

n=[]
for i in range(10000):
    a=torch.zeros([s,s,s], dtype=dtype)
    b=torch.randn([s,s,s], dtype=dtype)
    c=torch.randn([s,s,s], dtype=dtype)
    y=torch.baddbmm(a,b,c, beta=0, alpha=5)
    aa+=[a,b,c]
    n.append(y.isnan().float().mean().item())

>>> print(torch.mean(torch.tensor(n)).item())
1.1661808230201132e-06

Contributor Author

Anyway, it looks like that's a known bug: pytorch/pytorch#96037. It's been fixed in pytorch/pytorch#96086, but only for the next release of pytorch. So for now we should leave the zeros; I'll just simplify that code and add a reference to the issue in the comment. (In the future it could be updated to only use zeros for torch versions <= 2.0.0.)
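
If that cleanup is ever done, it might look roughly like this (a hedged sketch; the exact version bound would need to be checked against the pytorch release that actually contains the fix):

import torch
from packaging import version

# pytorch/pytorch#96037: baddbmm with beta=0 can propagate values from uninitialized
# memory on CPU. The fix only ships after the 2.0 series, so keep the zero baseline there.
needs_zero_baseline = version.parse(torch.__version__) < version.parse("2.1.0")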

Contributor

Thanks a mile for investigating!

GPTBigCodeForCausalLM,
GPTBigCodeForSequenceClassification,
GPTBigCodeForTokenClassification,
GPTBigCodeModel,
)
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention

torch.backends.cuda.matmul.allow_tf32 = False
Contributor Author

Why??? If it's needed, it would need a comment, and it can't go in the imports.

Contributor Author

The tests pass on an A100 without it, so I'll just remove it.

@jlamypoirier
Contributor Author

I made a few minor tweaks; I'm OK with merging if it works for everyone (assuming CI passes).

@younesbelkada younesbelkada (Contributor) left a comment

Looking great to me as well! Thanks a lot for all your work and investigations @jlamypoirier 🔥!
Thanks again for everything!

@younesbelkada younesbelkada merged commit e0921c6 into huggingface:main Apr 10, 2023
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
Add GPTBigCode model (Optimized GPT2 with MQA from Santacoder & BigCode) (huggingface#22575)

* Add model with cli tool

* Remove unwanted stuff

* Add new code

* Remove inference runner

* Style

* Fix checks

* Test updates

* make fixup

* fix docs

* fix doc

* fix test

* hopefully fix pipeline tests

* refactor

* fix CIs

* add comment

* rename to `GPTBigCodeForCausalLM`

* correct readme

* make fixup + docs

* make fixup

* fixes

* fixes

* Remove pruning

* Remove import

* Doc updates

* More pruning removal

* Combine copies

* Single MQA implementation, remove kv cache pre-allocation and padding

* Update doc

* Revert refactor to match gpt2 style

* Merge back key and value caches, fix some type hints

* Update doc

* Fix position ids with padding (PR 21080)

* Add conversion script temporarily

* Update conversion script

* Remove checkpoint conversion

* New model

* Fix MQA test

* Fix copies

* try fix tests

* FIX TEST!!

* remove  `DoubleHeadsModel`

* add MQA tests

* add slow tests

* clean up

* add CPU checker

* final fixes

* fixes

- fix GPU issue
- fixed slow tests
- skip disk offload

* fix final issue

* Simplify and comment baddbmm fix

* Remove unnecessary code

* Transpose tweaks

* Use beta=1 on cpu, improve tests

---------

Co-authored-by: younesbelkada <[email protected]>
@bharadwajymg

Any updates on supporting flash attention? Or is there a different PR to track it?

@ArthurZucker
Collaborator

cc @younesbelkada, I think this is supported in BetterTransformer, no?

@younesbelkada
Contributor

Indeed, this should go into the BetterTransformer API in the optimum library: https://github.com/huggingface/optimum
Once the feature is added there, you can just call model.to_bettertransformer() and benefit from the flash-attention backend. @bharadwajymg, would you mind opening a ticket there and requesting BetterTransformer support for the GPTBigCode model? Thanks!
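
Once that support exists in optimum, usage would presumably look like the sketch below (assumptions: optimum is installed, GPTBigCode has been added to its BetterTransformer backend, and the checkpoint name is only illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/gpt_bigcode-santacoder"  # illustrative checkpoint name
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16).to("cuda")

# Swap the attention implementation for the BetterTransformer / SDPA path.
model = model.to_bettertransformer()

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))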
