在 v6e TPU 上进行 MaxDiffusion 推理

本教程介绍了如何在 TPU v6e 上部署 MaxDiffusion 模型。在本教程中,您将使用 Stable Diffusion XL 模型生成图片。

准备工作

准备预配具有 4 个芯片的 TPU v6e:

  1. 按照设置 Cloud TPU 环境指南设置 Google Cloud 项目、配置 Google Cloud CLI、启用 Cloud TPU API,并确保您有权使用 Cloud TPU。

  2. 使用 Google Cloud 进行身份验证,并为 Google Cloud CLI 配置默认项目和区域。

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

保障容量

当您准备好预订 TPU 容量时,请参阅 Cloud TPU 配额,详细了解 Cloud TPU 配额。如果您对如何确保容量还有其他疑问,请与您的 Cloud TPU 销售团队或客户支持团队联系。

预配 Cloud TPU 环境

您可以使用 GKE、GKE 和 XPK 预配 TPU 虚拟机,也可以将其作为队列化资源预配。

前提条件

  • 验证您的项目是否有足够的 TPUS_PER_TPU_FAMILY 配额,该配额指定您可以在Google Cloud 项目中访问的芯片数量上限。
  • 验证您的项目是否有足够的 TPU 配额:
    • TPU 虚拟机配额
    • IP 地址配额
    • Hyperdisk Balanced 配额
  • 用户项目权限

预配 TPU v6e

   gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT

使用 listdescribe 命令查询队列中资源的状态。

   gcloud alpha compute tpus queued-resources describe QUEUED_RESOURCE_ID  \
      --project=PROJECT_ID --zone=ZONE

如需查看已加入队列的资源请求状态的完整列表,请参阅已加入队列的资源文档。

使用 SSH 连接到 TPU

   gcloud compute tpus tpu-vm ssh TPU_NAME

创建 Conda 环境

  1. 为 Miniconda 创建一个目录:

    mkdir -p ~/miniconda3
  2. 下载 Miniconda 安装程序脚本:

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. 安装 Miniconda:

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. 移除 Miniconda 安装程序脚本:

    rm -rf ~/miniconda3/miniconda.sh
  5. 将 Miniconda 添加到 PATH 变量:

    export PATH="$HOME/miniconda3/bin:$PATH"
  6. 重新加载 ~/.bashrc 以将更改应用于 PATH 变量:

    source ~/.bashrc
  7. 创建一个新的 Conda 环境:

    conda create -n tpu python=3.10
  8. 激活 Conda 环境:

    source activate tpu

设置 MaxDiffusion

  1. 克隆 MaxDiffusion 代码库并进入 MaxDiffusion 目录:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. 切换到 mlperf-4.1 分支:

    git checkout mlperf4.1
  3. 安装 MaxDiffusion:

    pip install -e .
  4. 安装依赖项:

    pip install -r requirements.txt
  5. 安装 JAX:

    pip install jax[tpu]==0.4.34 jaxlib==0.4.34 ml-dtypes==0.2.0 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  6. 安装其他依赖项:

     pip install huggingface_hub==0.25 absl-py flax tensorboardX google-cloud-storage torch tensorflow transformers 

生成图片

  1. 设置环境变量以配置 TPU 运行时:

    LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
  2. 使用 src/maxdiffusion/configs/base_xl.yml 中定义的提示和配置生成图片:

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"

    生成映像后,请务必清理 TPU 资源。

清理

删除 TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async