UNet2DConditionModel总体结构
stable diffusion 运行unet部分的代码。
noise_pred = self.unet(
sample=latent_model_input, #(2,4,64,64) 生成的latent
timestep=t, #时刻t
encoder_hidden_states=prompt_embeds, #(2,77,768) #输入的prompt和negative prompt 生成的embedding
timestep_cond=timestep_cond,#默认空
cross_attention_kwargs=self.cross_attention_kwargs, #默认空
added_cond_kwargs=added_cond_kwargs, #默认空
return_dict=False,
)[0]
1.time
get_time_embed使用了sinusoidal timestep embeddings,
time_embedding 使用了两个线性层和激活层进行映射,将320转换到1280。
如果还有class_labels,added_cond_kwargs等参数,也转换为embedding,并且相加。
t_emb = self.get_time_embed(sample=sample, timestep=timestep) #(2,320)
emb = self.time_embedding(t_emb, timestep_cond) #(2,1280)
2.pre-process
卷积转换,输入latent从(2,4,64,64) 到(2,320,64,64)
sample = self.conv_in(sample) #(2,320,64,64)
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
3.down
down_block 由三个CrossAttnDownBlock2D和一个DownBlock2D组成。输入包括:
-
hidden_states:latent
-
temb:时刻t的embdedding
-
encoder_hidden_states:prompt和negative prompt的embedding
网络结构
CrossAttnDownBlock2D(
ResnetBlock2D