AIGC笔记--Diffuser的训练pipeline

1--简单训练pipeline

import time
import numpy as np
import torch
from PIL import Image
import torchvision
import torch.nn.functional as F
from datasets import load_dataset
from torchvision import transforms
from matplotlib import pyplot as plt
from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline

# 数据增广
def transform(examples):
    preprocess = transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}

def process_dataset(batch_size):
    # 加载数据集
    # dataset = load_dataset("huggan/smithsonian_butterflies_subset", split = "train")
    dataset = load_dataset("/data-home/liujinfu/Diffuser/Data/smithsonian_butterflies_subset", split = "train")
    # 调用自定义的transform函数
    dataset.set_transform(transform)
    # 设置dataloader
    train_dataloader = torch.utils.data.DataLoader(
        dataset, 
        batch_size = batch_size, 
        shuffle = True
    )
    return train_dataloader

def train_loop(train_dataloader, noise_scheduler, model, num_epoches, device):
    # 优化器
    optimizer = torch.optim.AdamW(model.parameters(), lr = 4e-4)
    losses = []
    start_time = time.time() 
    for epoch in range(num_epoches):
        for _, batch in enumerate(train_dataloader): # 遍历
            clean_images = batch["images"].to(device) # B C H W
            # sample noise to add to the images
            noise = torch.randn(clean_images.shape).to(clean_images.device) # B C H W
            bs = clean_images.shape[0] # 64

            # sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.num_train_timesteps, (bs, ), device = clean_images.device
            ).long() # B

            # Add noise to the clean images according to the noise magnitude at each timestep
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) # 加噪

            # Get model prediction
            noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

            # Calculate the loss
            loss = F.mse_loss(noise_pred, noise) # 计算预测噪音和真实噪音之间的损失
            loss.backward(loss)
            losses.append(loss.item())

            # Update the model parameters with the optimizer
            optimizer.step()
            optimizer.zero_grad()

        if (epoch + 1) % 5 == 0:
            loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
            print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")

    end_time = time.time()
    elapsed_time = end_time - start_time # 记录训练时间
    
    print("time cost: ", elapsed_time)
    return losses

def vis(losses):
    # 可视化 loss
    fig, axs = plt.subplots(1, 2, figsize=(12, 4))
    axs[0].plot(losses)
    axs[1].plot(np.log(losses))
    return fig

def generate(model, noise_scheduler):
    # 1. create a pipeline
    image_pipe = DDPMPipeline(unet = model, scheduler = noise_scheduler)

    pipeline_output = image_pipe()
    return pipeline_output.images[0]

# 可视化生成图像
def show_images(x):
    """Given a batch of images x, make a grid and convert to PIL"""
    x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x)
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
    return grid_im

def make_grid(images, size=64):
    """Given a list of PIL images, stack them together into a line for easy viewing"""
    output_im = Image.new("RGB", (size * len(images), size))
    for i, im in enumerate(images):
        output_im.paste(im.resize((size, size)), (i * size, 0))
    return output_im

def main():
    # 获取训练集
    image_size = 32 
    batch_size = 64
    train_dataloader = process_dataset(batch_size = batch_size)

    # 设置Scheduler
    noise_scheduler = DDPMScheduler(num_train_timesteps = 1000, beta_schedule = "squaredcos_cap_v2") 

    # 创建Unet model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet2DModel(
        sample_size = image_size, # target image resolution
        in_channels = 3,
        out_channels = 3,
        layers_per_block = 2, # how many resnet layers to use per Unet block
        block_out_channels = (64, 128,128, 256),
        down_block_types = (
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D", # a regular ResNet upsampling block
        ),  
    ).to(device)
    
    # 开始训练
    losses = train_loop(train_dataloader = train_dataloader, 
               noise_scheduler = noise_scheduler,
               model = model,
               num_epoches = 30, 
               device = device)
    
    fig = vis(losses)
    fig.savefig("./loss.png")
    
    # 生成一张图片
    gen_img = generate(model, noise_scheduler)
    gen_img.save("./generate1.png")
    
    # 随机初始化噪音生成图片
    sample = torch.randn(8, 3, 32, 32).to(device)
    for i, t in enumerate(noise_scheduler.timesteps): # 反向去噪
        # Get model pred
        with torch.no_grad():
            residual = model(sample, t).sample
        # Update sample with step
        sample = noise_scheduler.step(residual, t, sample).prev_sample
    
    # 可视化生成的图片
    grid_im = show_images(sample)
    grid_im.save("./genearate2.png")
    print("All Done!")

if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/608236.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

TC8002D(3W音频功放IC)是一颗带关断模式的音频功放IC

一、概述 TC8002D是一颗带关断模式的音频功放IC。在5V输入电压下工作时,负载(3Ω)上的平均功率为3W,且失真度不超过10%。而对于手提设备而言,当VDD作用于关断端时,TC8002D将会进入关断模式,此时的功耗极低&…

机器学习算法之KNN分类算法【附python实现代码!可运行】

一、简介 在机器学习中,KNN(k-Nearest Neighbors)分类算法是一种简单且有效的监督学习算法,主要用于分类问题。KNN算法的基本思想是:在特征空间中,如果一个样本在特征空间中的k个最相邻的样本中的大多数属…

常见的一些RELAXED MODEL CONCEPTS

释放一致性(release consistency, RC) RC的核心观点是:使用 FENCE 围绕所有同步操作是多余的 同步获取 (acquire) 只需要一个后续的 FENCE,同步释放 (release) 只需要一个前面的 FENCE。 对于表 5.4 的临界区示例,可以省略 FENCE F11、F14…

Linux-笔记 修改开发板默认时区

1. 时区文件 使用命令date -R查看当前的默认时区,date - R命令会自动解析/etc/localtime 文件,而该文件又是指向“ /usr/share/zoneinfo/$主时区/$次时区 ”,当需要更改到指定的时区只要将/etc/localtime 文件软链接到 ”/usr/share/zoneinf…

Vue的省份联动

Vue的省份联动 一、安装依赖库 npm install element-china-area-data -Snpm install element-ui --save全局使用elemntui组件库 import ElementUI from element-ui; import element-ui/lib/theme-chalk/index.css;Vue.use(ElementUI);二 、代码如下 <template><div…

HarmonyOS开发之ArkTS使用:用户登录页面应用

目录 目录 前言 关于HarmonyOS 环境准备 新建项目 设计用户登录页面 1. 布局设计 2. 编写ArkTS代码 运行和测试 结束语 前言 随着HarmonyOS&#xff08;鸿蒙操作系统&#xff09;的不断发展&#xff0c;越来越多的开发者开始投入到这个全新的生态系统中&#xff0c;而…

BeyondCompare4 下载\安装\免费使用

1. 官网 下载 Download Beyond Compare Free Trial 2. 安装&#xff08;无脑下一步&#xff09; 3.永久免费使用 修改注册表 A、在搜索栏中输入 regedit &#xff0c;打开注册表 B、 删除项目&#xff1a;计算机 \HKEY_CURRENT_USER\Software\ScooterSoftware\Beyond Compar…

物联网实战--平台篇之(五)账户界面

目录 一、界面框架 二、首页(未登录) 三、验证码登录 四、密码登录 五、帐号注册 六、忘记密码 本项目的交流QQ群:701889554 物联网实战--入门篇https://blog.csdn.net/ypp240124016/category_12609773.html 物联网实战--驱动篇https://blog.csdn.net/ypp240124016/cat…

10. Django Auth认证系统

10. Auth认证系统 Django除了内置的Admin后台系统之外, 还内置了Auth认证系统. 整个Auth认证系统可分为三大部分: 用户信息, 用户权限和用户组, 在数据库中分别对应数据表auth_user, auth_permission和auth_group.10.1 内置User实现用户管理 用户管理是网站必备的功能之一, D…

远动通讯屏,组成和功能介绍

远动通讯屏&#xff0c;组成和功能介绍 远动通讯屏是基于电网安全建设而投入的远方监控厂站信息、远方切除电网负荷的设备&#xff1b;主经是由远动装置、通讯管理机、交换机、GPS对时装置、数字通道防雷器、模拟通道防雷器、屏柜及附件等设备组成。变电站远动通讯系统是指对广…

安装oh-my-zsh(命令行工具)

文章目录 一、安装zsh、git、wget二、安装运行脚本1、curl/wget下载2、手动下载 三、切换主题1、编辑配置文件2、切换主题 四、安装插件1、zsh-syntax-highlighting&#xff08;高亮语法错误&#xff09;2、zsh-autosuggestions&#xff08;自动补全&#xff09; 五、更多优化配…

顺序表的实现(迈入数据结构的大门)(2)

目录 顺序表的头插(SLPushFront) 此时&#xff1a;我们有两个思路&#xff08;数组移位&#xff09; 顺序表的头删(学会思维的变换)(SLPopFront) 顺序表的尾插(SLPushBack) 有尾插就有尾删 既然头与尾部的插入与删除都有&#xff0c;那必然少不了指定位置的插入删除 查找…

汽车之家,如何在“以旧换新”浪潮中大展拳脚?

北京车展刚刚落幕&#xff0c;两重利好正主导汽车市场持续升温&#xff1a;新能源渗透率首破50%&#xff0c;以及以旧换新详细政策进入落地期。 图源&#xff1a;中国政府网 在政策的有力指引下&#xff0c;汽车产业链的各个环节正经历着一场深刻的“连锁反应”。在以旧换新的…

\boldsymbol无法使用

检查是否导入了 unicode-math 宏包、 没有加粗效果 正常加粗了 2024-5-9-15点35分

(八)JSP教程——application对象

application对象是一个比较重要的对象&#xff0c;服务器在启动后就会产生这个application对象&#xff0c;所有连接到服务器的客户端application对象都是相同的&#xff0c;所有的客户端共享这个内置的application对象&#xff0c;直到服务器关闭为止。 可以使用application对…

【SpringBoot记录】自动配置原理(1):依赖管理

前言 我们都知道SpringBoot能快速创建Spring应用&#xff0c;其核心优势就在于自动配置功能&#xff0c;它通过一系列的约定和内置的配置来减少开发者手动配置的工作。下面通过最简单的案例分析SpringBoot的功能特性&#xff0c;了解自动配置原理。 SpringBoot简单案例 根据S…

Linux下的SPI通信

SPI通信 一. 1.SPI简介: SPI 是一种高速,全双工,同步串行总线。 SPI 有主从俩种模式通常由一个主设备和一个或者多个从设备组从。SPI不支持多主机。 SPI通信至少需要四根线,分别是 MISO(主设备数据输入,从设备输出),MOSI (主设数据输出从设备输入),SCLK(时钟信号),CS/SS…

leetcode尊享面试100题(549二叉树最长连续序列||,python)

题目不长&#xff0c;就是分析时间太久了。 思路使用dfs深度遍历&#xff0c;先想好这个函数返回什么&#xff0c;题目给出路径可以是子-父-子的路径&#xff0c;那么1-2-3可以&#xff0c;3-2-1也可以&#xff0c;那么考虑dfs返回两个值&#xff0c;对于当前节点node来说&…

BI赋能金融新质生产力,16家金融机构智能BI创新实践分享

2024年政府工作报告强调&#xff0c;要“大力发展科技金融、绿色金融、普惠金融、养老金融、数字金融”&#xff0c;同时“大力推进现代化产业体系建设&#xff0c;加快发展新质生产力”。对于金融行业而言&#xff0c;培育新质生产力是高质量发展的关键着力点。金融机构可以通…

vue项目启动后页面显示‘Cannot GET /’

1、npm run dev命令启动项目的时候没有报错&#xff0c;页面打开却提示 Cannot GET / 2.这个时候只需要找到config文件夹下面的index.js文件。把assetsPublicPath字符串的&#xff1a;‘./’修改成 ‘/’就行了。修改完之后记得关闭项目&#xff0c;然后重新启动。不然不会生效…
最新文章