How to calculate the cpu memory required for DeepSpeedZeRoOffload initialization? #3606

Closed
treya-lin opened this issue May 25, 2023 · 2 comments

treya-lin commented May 25, 2023

Hi, I am using DeepSpeed for training, but it threw an error saying there is not enough CPU memory for initialization.

Here is the command I ran:
(The data (.pt file) in dataset_path is a small one generated from just 100 lines of text. I only want to check whether training can start properly in my environment.)

deepspeed pretrain.py --deepspeed --deepspeed_config models/deepspeed_zero3_config.json --enable_zero3 \
 --pretrained_model_path ${modeldir}/chinese_llama_7b.bin \
 --dataset_path ${datadir}/chinese_llama_small_test1.pt \
 --spm_model_path ${modeldir}/tokenizer.model \
 --config_path models/llama/7b_config.json \
 --output_model_path ${outdir}/chinese_llama_7b_pretrain_small_test \
 --world_size 1 --data_processor lm --deepspeed_checkpoint_activations \
 --total_steps 10000 --save_checkpoint_steps 5000 --batch_size 24

Here is the config for zero_optimization:

  "zero_optimization": {
      "stage": 3,
      "offload_param": {
          "device": "cpu",
          "pin_memory": true
      },
      "offload_optimizer": {
          "device": "cpu",
          "pin_memory":true
      }
  },
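
For context, this zero_optimization block sits inside the larger DeepSpeed JSON config passed via --deepspeed_config. Below is a minimal sketch of what the rest of models/deepspeed_zero3_config.json might look like, assuming the fp16 training and Adam hyperparameters implied by the log further down (lr 2e-5, betas 0.9/0.999, weight decay 0, flops profiler enabled); the actual file in this setup may differ:

  {
      "train_micro_batch_size_per_gpu": 24,
      "gradient_accumulation_steps": 1,
      "fp16": {
          "enabled": true
      },
      "optimizer": {
          "type": "Adam",
          "params": {
              "lr": 2e-5,
              "betas": [0.9, 0.999],
              "weight_decay": 0.0
          }
      },
      "flops_profiler": {
          "enabled": true
      },
      "zero_optimization": {
          "stage": 3,
          "offload_param": { "device": "cpu", "pin_memory": true },
          "offload_optimizer": { "device": "cpu", "pin_memory": true }
      }
  }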

Here is the full log.

[2023-05-24 10:00:03,687] [WARNING] [runner.py:186:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2023-05-24 10:00:03,760] [INFO] [runner.py:550:main] cmd = /opt/conda/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None pretrain.py --deepspeed --deepspeed_config models/deepspeed_zero3_config.json --enable_zero3 --pretrained_model_path /workspace/TencentPretrain/../models/llama_chinese/chinese_llama_7b.bin --dataset_path /workspace/TencentPretrain/../datasets/preprocessed/chinese_llama_small_test1.pt --spm_model_path /workspace/TencentPretrain/../models/llama_chinese/tokenizer.model --config_path models/llama/7b_config.json --output_model_path /workspace/TencentPretrain/../models/output/chinese_llama_7b_pretrain_small_test --world_size 1 --data_processor lm --deepspeed_checkpoint_activations --total_steps 10000 --save_checkpoint_steps 5000 --batch_size 24
[2023-05-24 10:00:05,026] [INFO] [launch.py:135:main] 0 NV_LIBNCCL_DEV_PACKAGE=libnccl-dev=2.9.9-1+cuda11.3
[2023-05-24 10:00:05,026] [INFO] [launch.py:135:main] 0 NCCL_VERSION=2.9.9-1
[2023-05-24 10:00:05,026] [INFO] [launch.py:135:main] 0 NV_LIBNCCL_PACKAGE_VERSION=2.9.9-1
[2023-05-24 10:00:05,026] [INFO] [launch.py:135:main] 0 NV_LIBNCCL_PACKAGE=libnccl2=2.9.9-1+cuda11.3
[2023-05-24 10:00:05,026] [INFO] [launch.py:135:main] 0 NV_LIBNCCL_DEV_PACKAGE_NAME=libnccl-dev
[2023-05-24 10:00:05,026] [INFO] [launch.py:135:main] 0 NV_LIBNCCL_PACKAGE_NAME=libnccl2
[2023-05-24 10:00:05,026] [INFO] [launch.py:135:main] 0 NV_LIBNCCL_DEV_PACKAGE_VERSION=2.9.9-1
[2023-05-24 10:00:05,026] [INFO] [launch.py:142:main] WORLD INFO DICT: {'localhost': [0]}
[2023-05-24 10:00:05,026] [INFO] [launch.py:149:main] nnodes=1, num_local_procs=1, node_rank=0
[2023-05-24 10:00:05,026] [INFO] [launch.py:161:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
[2023-05-24 10:00:05,026] [INFO] [launch.py:162:main] dist_world_size=1
[2023-05-24 10:00:05,026] [INFO] [launch.py:164:main] Setting CUDA_VISIBLE_DEVICES=0
[2023-05-24 10:00:06,577] [INFO] [comm.py:654:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2023-05-24 10:00:25,630] [INFO] [partition_parameters.py:416:__exit__] finished initializing model with 6.74B parameters
[2023-05-24 10:01:55,229] [WARNING] [cpu_adam.py:86:__init__] FP16 params for CPUAdam may not work on AMD CPUs
Using /root/.cache/torch_extensions/py37_cu113 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py37_cu113/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cpu_adam...
Time to load cpu_adam op: 2.713547706604004 seconds
Adam Optimizer #0 is created with AVX2 arithmetic capability.
Config: alpha=0.000020, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1
[2023-05-24 10:02:10,518] [INFO] [logging.py:93:log_dist] [Rank 0] DeepSpeed info: version=0.8.3, git-hash=unknown, git-branch=unknown
[2023-05-24 10:02:10,532 INFO] Added key: store_based_barrier_key:2 to store for rank: 0
[2023-05-24 10:02:10,532 INFO] Rank 0: Completed store-based barrier for key:store_based_barrier_key:2 with 1 nodes.
[2023-05-24 10:02:10,533] [INFO] [logging.py:93:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: True
[2023-05-24 10:02:10,534] [INFO] [logging.py:93:log_dist] [Rank 0] Removing param_group that has no 'params' in the client Optimizer
[2023-05-24 10:02:10,534] [INFO] [logging.py:93:log_dist] [Rank 0] Using client Optimizer as basic optimizer
[2023-05-24 10:02:10,546] [INFO] [logging.py:93:log_dist] [Rank 0] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam
[2023-05-24 10:02:10,546] [INFO] [utils.py:56:is_zero_supported_optimizer] Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=<class 'deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam'>
[2023-05-24 10:02:10,546] [INFO] [logging.py:93:log_dist] [Rank 0] Creating torch.float16 ZeRO stage 3 optimizer
[2023-05-24 10:02:10,614] [INFO] [utils.py:829:see_memory_usage] Stage 3 initialize beginning
[2023-05-24 10:02:10,614] [INFO] [utils.py:834:see_memory_usage] MA 0.0 GB         Max_MA 0.73 GB         CA 0.73 GB         Max_CA 1 GB 
[2023-05-24 10:02:10,615] [INFO] [utils.py:839:see_memory_usage] CPU Virtual Memory:  used = 23.44 GB, percent = 24.8%
[2023-05-24 10:02:10,617] [INFO] [stage3.py:113:__init__] Reduce bucket size 500000000
[2023-05-24 10:02:10,617] [INFO] [stage3.py:114:__init__] Prefetch bucket size 50000000
Using /root/.cache/torch_extensions/py37_cu113 as PyTorch extensions root...
Emitting ninja build file /root/.cache/torch_extensions/py37_cu113/utils/build.ninja...
Building extension module utils...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module utils...
Time to load utils op: 0.2768416404724121 seconds
[2023-05-24 10:02:10,934] [INFO] [utils.py:829:see_memory_usage] DeepSpeedZeRoOffload initialize [begin]
[2023-05-24 10:02:10,935] [INFO] [utils.py:834:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.73 GB         Max_CA 1 GB 
[2023-05-24 10:02:10,935] [INFO] [utils.py:839:see_memory_usage] CPU Virtual Memory:  used = 23.44 GB, percent = 24.8%
Parameter Offload: Total persistent parameters: 266240 in 65 params
[2023-05-24 10:02:10,993] [INFO] [utils.py:829:see_memory_usage] DeepSpeedZeRoOffload initialize [end]
[2023-05-24 10:02:10,994] [INFO] [utils.py:834:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.73 GB         Max_CA 1 GB 
[2023-05-24 10:02:10,994] [INFO] [utils.py:839:see_memory_usage] CPU Virtual Memory:  used = 23.44 GB, percent = 24.8%
[2023-05-24 10:02:11,033] [INFO] [utils.py:829:see_memory_usage] Before creating fp16 partitions
[2023-05-24 10:02:11,034] [INFO] [utils.py:834:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.73 GB         Max_CA 1 GB 
[2023-05-24 10:02:11,034] [INFO] [utils.py:839:see_memory_usage] CPU Virtual Memory:  used = 23.44 GB, percent = 24.8%
[2023-05-24 10:02:16,329] [INFO] [utils.py:829:see_memory_usage] After creating fp16 partitions: 7
[2023-05-24 10:02:16,330] [INFO] [utils.py:834:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.73 GB         Max_CA 1 GB 
[2023-05-24 10:02:16,330] [INFO] [utils.py:839:see_memory_usage] CPU Virtual Memory:  used = 39.73 GB, percent = 42.1%
[2023-05-24 10:02:16,369] [INFO] [utils.py:829:see_memory_usage] Before creating fp32 partitions
[2023-05-24 10:02:16,370] [INFO] [utils.py:834:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.73 GB         Max_CA 1 GB 
[2023-05-24 10:02:16,370] [INFO] [utils.py:839:see_memory_usage] CPU Virtual Memory:  used = 39.73 GB, percent = 42.1%
[2023-05-24 10:02:19,337] [INFO] [utils.py:829:see_memory_usage] After creating fp32 partitions
[2023-05-24 10:02:19,337] [INFO] [utils.py:834:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.73 GB         Max_CA 1 GB 
[2023-05-24 10:02:19,337] [INFO] [utils.py:839:see_memory_usage] CPU Virtual Memory:  used = 64.88 GB, percent = 68.7%
[2023-05-24 10:02:19,376] [INFO] [utils.py:829:see_memory_usage] Before initializing optimizer states
[2023-05-24 10:02:19,377] [INFO] [utils.py:834:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.73 GB         Max_CA 1 GB 
[2023-05-24 10:02:19,377] [INFO] [utils.py:839:see_memory_usage] CPU Virtual Memory:  used = 64.88 GB, percent = 68.7%
Traceback (most recent call last):
  File "pretrain.py", line 134, in <module>
    main()
  File "pretrain.py", line 130, in main
    trainer.train_and_validate(args)
  File "/workspace/TencentPretrain/tencentpretrain/trainer.py", line 79, in train_and_validate
    worker(args.local_rank, None, args, model_for_training, model_for_dataloader)
  File "/workspace/TencentPretrain/tencentpretrain/trainer.py", line 634, in worker
    dist_init_required=False)
  File "/opt/conda/lib/python3.7/site-packages/deepspeed/__init__.py", line 135, in initialize
    config_params=config_params)
  File "/opt/conda/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 340, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "/opt/conda/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 1298, in _configure_optimizer
    self.optimizer = self._configure_zero_optimizer(basic_optimizer)
  File "/opt/conda/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 1626, in _configure_zero_optimizer
    communication_data_type=self.communication_data_type)
  File "/opt/conda/lib/python3.7/site-packages/deepspeed/runtime/zero/stage3.py", line 312, in __init__
    self._setup_for_real_optimizer()
  File "/opt/conda/lib/python3.7/site-packages/deepspeed/runtime/zero/stage3.py", line 371, in _setup_for_real_optimizer
    self.initialize_optimizer_states()
  File "/opt/conda/lib/python3.7/site-packages/deepspeed/runtime/zero/stage3.py", line 926, in initialize_optimizer_states
    device=self.device)
RuntimeError: [enforce fail at alloc_cpu.cpp:66] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 4047667200 bytes. Error code 12 (Cannot allocate memory)
[2023-05-24 10:02:31,188] [INFO] [launch.py:318:sigkill_handler] Killing subprocess 228
[2023-05-24 10:02:31,189] [ERROR] [launch.py:324:sigkill_handler] ['/opt/conda/bin/python', '-u', 'pretrain.py', '--local_rank=0', '--deepspeed', '--deepspeed_config', 'models/deepspeed_zero3_config.json', '--enable_zero3', '--pretrained_model_path', '/workspace/TencentPretrain/../models/llama_chinese/chinese_llama_7b.bin', '--dataset_path', '/workspace/TencentPretrain/../datasets/preprocessed/chinese_llama_small_test1.pt', '--spm_model_path', '/workspace/TencentPretrain/../models/llama_chinese/tokenizer.model', '--config_path', 'models/llama/7b_config.json', '--output_model_path', '/workspace/TencentPretrain/../models/output/chinese_llama_7b_pretrain_small_test', '--world_size', '1', '--data_processor', 'lm', '--deepspeed_checkpoint_activations', '--total_steps', '10000', '--save_checkpoint_steps', '5000', '--batch_size', '24'] exits with return code = 1

The pretrained model I am starting from has about 7B parameters and is about 13 GB on disk. I am trying to continue training from this model.

Hardware-wise, I have one GPU (A100, 40 GB) and about 94 GB of CPU memory available, which I assumed was not particularly small (or is it?).

free -h
              total        used        free      shared  buff/cache   available
Mem:            94G        2.5G         89G         25M        2.5G         90G
Swap:            0B          0B          0B

My local environment:

deepspeed 0.8.3, torch 1.12.1, python 3.7, cuda 11.3, cudnn 8.3.2

I am new to DeepSpeed and to training large models, so please let me know if my description is unclear. Is there any way to work around this OOM issue? And how can I calculate the memory required for DeepSpeed training given a known pretrained model size? Thank you!

treya-lin added the bug and training labels on May 25, 2023
treya-lin changed the title from "[BUG] How to calculate the cpu memory required for DeepSpeedZeRoOffload initialization?" to "How to calculate the cpu memory required for DeepSpeedZeRoOffload initialization?" on May 25, 2023
tjruwase (Contributor) commented Jun 2, 2023

@treya-lin, since you are offloading both parameters and optimizer states to CPU, you would need roughly 18 bytes per model parameter. That means that for a 7B model you would need ~126GB of CPU memory. Please see page 3 of https://arxiv.org/pdf/1910.02054.pdf for a discussion of the memory breakdown.
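
As a back-of-the-envelope check, here is a minimal sketch of that calculation. The per-parameter breakdown below (fp16 params plus fp32 master weights, Adam momentum/variance, and fp32 gradients on CPU) is an illustrative accounting consistent with the ~18 bytes/parameter figure, not an official DeepSpeed formula:

GB = 1e9  # decimal gigabytes, matching the numbers used in this thread

# One plausible per-parameter accounting that sums to ~18 bytes with full
# CPU offload (illustrative assumption, not an exact DeepSpeed breakdown):
bytes_per_param = (
    2      # fp16 parameter copy offloaded to CPU (offload_param)
    + 4    # fp32 master copy of the parameters (optimizer state)
    + 4    # fp32 Adam momentum
    + 4    # fp32 Adam variance
    + 4    # fp32 gradient / optimizer working buffer on CPU
)          # = 18 bytes per parameter

num_params = 6.74e9  # "6.74B parameters" reported in the log above
print(f"~{num_params * bytes_per_param / GB:.0f} GB of CPU memory")  # ~121 GB
# ~121 GB (or ~126 GB for a round 7B) is well above the 94 GB available here,
# which is why the allocation fails while initializing the optimizer states.

If your DeepSpeed version ships them, the memory estimators under deepspeed.runtime.zero (for example estimate_zero3_model_states_mem_needs_all_live) can print a similar per-configuration estimate for an instantiated model.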

tjruwase closed this as completed on Jun 2, 2023
treya-lin (Author) commented:

@treya-lin, since you are offloading both parameters and optimizer states to CPU, you would need roughly 18 bytes per model parameter. That means that for a 7B model you would need ~126GB of CPU memory. Please see page 3 of https://arxiv.org/pdf/1910.02054.pdf for a discussion of the memory breakdown.

Hi I see! Thanks a lot!
