Implementing VGG image classification with libtorch

The code below is organized into four parts: the VGG network definition (vgg.h), an OpenCV-backed custom dataset (custom_dataset.h), a training program, and a standalone inference program.

// vgg.h: VGG model definition and factory
#include <torch/torch.h>
#include <stdexcept>
#include <string>
#include <vector>

// VGG configuration: conv_layers holds the output channels of each conv layer,
// with -1 marking a 2x2 max-pooling layer.
struct VGGConfig {
    std::vector<int> conv_layers;
    bool use_batch_norm;
    int num_classes;
};

class VGGImpl : public torch::nn::Module {
public:
    explicit VGGImpl(const VGGConfig& config) {
        // Build the feature extractor.
        int in_channels = 3;

        for (size_t i = 0; i < config.conv_layers.size(); ++i) {
            int out_channels = config.conv_layers[i];

            if (out_channels == -1) { // max-pooling layer
                features->push_back(
                    torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)));
            } else { // 3x3 convolution, padding 1
                features->push_back(torch::nn::Conv2d(
                    torch::nn::Conv2dOptions(in_channels, out_channels, 3)
                        .padding(1)));

                if (config.use_batch_norm) {
                    features->push_back(torch::nn::BatchNorm2d(out_channels));
                }

                features->push_back(torch::nn::ReLU());
                in_channels = out_channels;
            }
        }
        // Registering the Sequential once also registers every layer pushed into it,
        // so separate per-layer register_module calls would only duplicate parameters.
        register_module("features", features);

        // Classifier head (assumes 224x224 input, i.e. a 512x7x7 feature map).
        classifier = register_module("classifier",
            torch::nn::Sequential(
                torch::nn::Linear(512 * 7 * 7, 4096),
                torch::nn::ReLU(),
                torch::nn::Dropout(0.5),
                torch::nn::Linear(4096, 4096),
                torch::nn::ReLU(),
                torch::nn::Dropout(0.5),
                torch::nn::Linear(4096, config.num_classes)));
    }

    torch::Tensor forward(torch::Tensor x) {
        x = features->forward(x);
        x = torch::flatten(x, 1); // keep the batch dimension, flatten the rest
        x = classifier->forward(x);
        return x;
    }

private:
    // features is default-constructed (non-null) so layers can be pushed onto it
    // in the constructor; classifier is assigned and registered there.
    torch::nn::Sequential features;
    torch::nn::Sequential classifier{nullptr};
};

TORCH_MODULE(VGG);

// Factory for the standard VGG variants; -1 entries in the config mark max-pool layers.
// Marked inline so this definition can live in a header included from several sources.
inline VGG create_vgg(int version = 16, int num_classes = 1000, bool batch_norm = false) {
    std::vector<int> config;
    
    switch(version) {
        case 11:
            config = {64, -1, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1};
            break;
        case 13:
            config = {64, 64, -1, 128, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1};
            break;
        case 16:
            config = {64, 64, -1, 128, 128, -1, 256, 256, 256, -1, 512, 512, 512, -1, 512, 512, 512, -1};
            break;
        case 19:
            config = {64, 64, -1, 128, 128, -1, 256, 256, 256, 256, -1, 
                     512, 512, 512, 512, -1, 512, 512, 512, 512, -1};
            break;
        default:
            throw std::runtime_error("Unsupported VGG version");
    }
    
    return VGG(VGGConfig{config, batch_norm, num_classes});
}
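For a quick sanity check that the network builds and produces the expected output shape, a small sketch like the one below can go into a throwaway main() (the 10-class, batch-norm configuration and the random input are only illustrations):

// Build a VGG16 with batch norm for 10 classes and push a random 224x224 batch through it.
auto net = create_vgg(16, 10, /*batch_norm=*/true);
auto out = net->forward(torch::randn({2, 3, 224, 224}));
std::cout << out.sizes() << std::endl;  // expect [2, 10]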
// custom_dataset.h: image-folder dataset backed by OpenCV
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <algorithm>
#include <filesystem>
#include <stdexcept>
#include <string>
#include <vector>

// Custom dataset: expects root_dir/<class_name>/*.jpg or *.png
class CustomDataset : public torch::data::Dataset<CustomDataset> {
public:
    explicit CustomDataset(const std::string& root_dir, 
                         const std::vector<std::string>& class_names,
                         int image_size = 224,
                         bool is_train = true) 
        : image_size_(image_size), is_train_(is_train) {
        
        // Walk root_dir/<class_name>/ and assign an integer label to each class.
        int label = 0;
        for (const auto& class_name : class_names) {
            std::string class_dir = root_dir + "/" + class_name;
            
            for (const auto& entry : std::filesystem::directory_iterator(class_dir)) {
                if (entry.path().extension() == ".jpg" || 
                    entry.path().extension() == ".png") {
                    image_paths_.push_back(entry.path().string());
                    labels_.push_back(label);
                }
            }
            label++;
        }
        
        // Keep the label -> class-name mapping for later lookups.
        class_map_ = class_names;
    }

    // Load, preprocess, and return one (image, label) example.
    torch::data::Example<> get(size_t index) override {
        cv::Mat image = cv::imread(image_paths_[index]);
        if (image.empty()) {
            throw std::runtime_error("Failed to load image: " + image_paths_[index]);
        }

        // Augmentation / resizing (see preprocess_image below).
        image = preprocess_image(image);

        // OpenCV loads BGR; convert to RGB so the ImageNet statistics below apply.
        cv::cvtColor(image, image, cv::COLOR_BGR2RGB);

        // HWC uint8 -> CHW float in [0, 1]. clone() copies out of the cv::Mat buffer,
        // because from_blob does not take ownership of the data.
        auto tensor = torch::from_blob(
            image.data, {image.rows, image.cols, 3}, torch::kByte).clone();
        tensor = tensor.permute({2, 0, 1}).to(torch::kFloat32).div_(255);

        // Per-channel normalization with the ImageNet mean/std.
        tensor[0] = tensor[0].sub_(0.485).div_(0.229);
        tensor[1] = tensor[1].sub_(0.456).div_(0.224);
        tensor[2] = tensor[2].sub_(0.406).div_(0.225);

        int64_t label = labels_[index];
        return {tensor, torch::tensor(label)};
    }

    // Number of samples in the dataset.
    torch::optional<size_t> size() const override {
        return image_paths_.size();
    }

    // Map an integer label back to its class name.
    std::string get_class_name(int label) const {
        return class_map_.at(label);
    }

private:
    cv::Mat preprocess_image(cv::Mat image) {
        if (is_train_) {
            // Training-time augmentation.
            // Random horizontal flip.
            if (torch::rand(1).item<float>() > 0.5) {
                cv::flip(image, image, 1);
            }

            // Random resized crop: scale the shorter side to a random size in
            // [image_size_, 1.25 * image_size_], then take a random image_size_ crop.
            float scale = 1.0f + 0.25f * torch::rand(1).item<float>();
            int short_side = static_cast<int>(image_size_ * scale);
            int new_h, new_w;
            if (image.rows <= image.cols) {
                new_h = short_side;
                new_w = std::max(short_side,
                    static_cast<int>(image.cols * short_side / static_cast<float>(image.rows)));
            } else {
                new_w = short_side;
                new_h = std::max(short_side,
                    static_cast<int>(image.rows * short_side / static_cast<float>(image.cols)));
            }
            cv::resize(image, image, cv::Size(new_w, new_h));

            int y = torch::randint(0, new_h - image_size_ + 1, {1}, torch::kLong).item<int>();
            int x = torch::randint(0, new_w - image_size_ + 1, {1}, torch::kLong).item<int>();
            image = image(cv::Rect(x, y, image_size_, image_size_)).clone();
        } else {
            // Deterministic validation/test preprocessing: scale both sides to
            // image_size_ * 256/224, then center-crop image_size_ x image_size_.
            cv::resize(image, image,
                cv::Size(image_size_ * 256 / 224, image_size_ * 256 / 224));
            int center_y = image.rows / 2;
            int center_x = image.cols / 2;
            cv::Rect roi(center_x - image_size_ / 2, center_y - image_size_ / 2,
                         image_size_, image_size_);
            image = image(roi).clone();
        }

        return image;
    }

    std::vector<std::string> image_paths_;
    std::vector<int64_t> labels_;
    std::vector<std::string> class_map_;
    int image_size_;
    bool is_train_;
};
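As a quick illustration of how the dataset behaves on its own, one sample can be fetched and inspected directly in a small test main() (the directory layout and class names below are placeholders):

// Fetch and inspect the first training sample.
CustomDataset ds("data/train", {"cat", "dog", "bird"}, 224, /*is_train=*/true);
auto example = ds.get(0);
std::cout << example.data.sizes()                     // expect [3, 224, 224]
          << " label " << example.target.item<int64_t>() << std::endl;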
// Training program: trains VGG16 on a custom image-folder dataset.
#include "vgg.h"
#include "custom_dataset.h"
#include <torch/torch.h>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
void train_custom(const std::string& train_dir, 
                 const std::string& val_dir,
                 const std::vector<std::string>& class_names,
                 int num_epochs = 50,
                 int batch_size = 32) {
    // 1. Build the model and pick a device (fall back to CPU when CUDA is unavailable).
    torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
    auto model = create_vgg(16, static_cast<int>(class_names.size())); // VGG16 with custom classes
    model->to(device);
    
    // 2. Build the datasets and data loaders. Shuffling is controlled by the sampler
    //    type (DataLoaderOptions has no shuffle flag): RandomSampler for training,
    //    SequentialSampler for validation.
    auto train_dataset = CustomDataset(train_dir, class_names, 224, true)
        .map(torch::data::transforms::Stack<>());

    auto val_dataset = CustomDataset(val_dir, class_names, 224, false)
        .map(torch::data::transforms::Stack<>());

    auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
        std::move(train_dataset),
        torch::data::DataLoaderOptions().batch_size(batch_size).workers(4));

    auto val_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
        std::move(val_dataset),
        torch::data::DataLoaderOptions().batch_size(batch_size).workers(4));
    
    // 3. Optimizer and loss function.
    torch::optim::Adam optimizer(
        model->parameters(),
        torch::optim::AdamOptions(1e-4).weight_decay(1e-4));

    auto criterion = torch::nn::CrossEntropyLoss();
    
    // 4. Training loop.
    for (int epoch = 1; epoch <= num_epochs; ++epoch) {
        // --- Training phase ---
        model->train();
        float running_loss = 0.0;
        int correct = 0;
        int total = 0;

        for (auto& batch : *train_loader) {
            auto data = batch.data.to(device);
            auto targets = batch.target.to(device);

            optimizer.zero_grad();
            auto outputs = model->forward(data);
            auto loss = criterion(outputs, targets);
            loss.backward();
            optimizer.step();

            // Accumulate the summed loss so the epoch value is an average per sample.
            running_loss += loss.item<float>() * data.size(0);
            auto predicted = torch::argmax(outputs, 1);
            total += targets.size(0);
            correct += (predicted == targets).sum().item<int>();
        }

        float train_acc = 100.0 * correct / total;
        float train_loss = running_loss / total;
        
        // --- Validation phase (no gradient tracking needed) ---
        model->eval();
        running_loss = 0.0;
        correct = 0;
        total = 0;
        {
            torch::NoGradGuard no_grad;
            for (auto& batch : *val_loader) {
                auto data = batch.data.to(device);
                auto targets = batch.target.to(device);

                auto outputs = model->forward(data);
                auto loss = criterion(outputs, targets);

                running_loss += loss.item<float>() * data.size(0);
                auto predicted = torch::argmax(outputs, 1);
                total += targets.size(0);
                correct += (predicted == targets).sum().item<int>();
            }
        }

        float val_acc = 100.0 * correct / total;
        float val_loss = running_loss / total;
        
        // Print epoch statistics.
        std::cout << "Epoch [" << epoch << "/" << num_epochs << "]\n"
                  << "Train Loss: " << train_loss << " Acc: " << train_acc << "%\n"
                  << "Val Loss: " << val_loss << " Acc: " << val_acc << "%\n\n";
    }
    
    // 5. Save the trained model.
    torch::save(model, "custom_vgg16_model.pt");
}

int main() {
    try {
        // Example class names (change to match the actual dataset).
        std::vector<std::string> class_names = {"cat", "dog", "bird"};
        
        train_custom("data/train", "data/val", class_names, 50, 32);
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
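To make long runs resumable, the optimizer state can be serialized alongside the model at a checkpoint. A minimal sketch, with illustrative file names, that could sit at the end of each epoch:

// Save model weights and optimizer state so training can resume later.
torch::save(model, "vgg16_checkpoint.pt");
torch::save(optimizer, "vgg16_optimizer.pt");

// ...and to resume:
torch::load(model, "vgg16_checkpoint.pt");
torch::load(optimizer, "vgg16_optimizer.pt");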
// Inference program: classify a single image with the trained model.
#include "vgg.h"
#include "custom_dataset.h"
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <iostream>
#include <string>
#include <vector>

int main() {
    try {
        // 1. Rebuild the model and load the weights.
        //    The class names must be in the same order used during training.
        std::vector<std::string> class_names = {"cat", "dog", "bird"};
        auto model = create_vgg(16, static_cast<int>(class_names.size()));
        torch::load(model, "custom_vgg16_model.pt");
        torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
        model->to(device);
        model->eval();

        // 2. Preprocess the image exactly like the validation pipeline.
        cv::Mat image = cv::imread("test_image.jpg");
        if (image.empty()) {
            throw std::runtime_error("Failed to load image");
        }

        // Deterministic resize, then center crop.
        int image_size = 224;
        cv::resize(image, image, cv::Size(image_size * 256 / 224, image_size * 256 / 224));
        int center_y = image.rows / 2;
        int center_x = image.cols / 2;
        cv::Rect roi(center_x - image_size / 2, center_y - image_size / 2, image_size, image_size);
        image = image(roi).clone();

        // OpenCV loads BGR; convert to RGB before applying ImageNet statistics.
        cv::cvtColor(image, image, cv::COLOR_BGR2RGB);

        // HWC uint8 -> CHW float in [0, 1].
        auto tensor = torch::from_blob(
            image.data, {image.rows, image.cols, 3}, torch::kByte).clone();
        tensor = tensor.permute({2, 0, 1}).to(torch::kFloat32).div_(255);

        // Per-channel ImageNet normalization.
        tensor[0] = tensor[0].sub_(0.485).div_(0.229);
        tensor[1] = tensor[1].sub_(0.456).div_(0.224);
        tensor[2] = tensor[2].sub_(0.406).div_(0.225);

        // 3. Run inference without tracking gradients.
        torch::NoGradGuard no_grad;
        auto input_tensor = tensor.unsqueeze(0).to(device);
        auto output = model->forward(input_tensor);
        auto probs = torch::softmax(output, 1);
        auto predicted_idx = torch::argmax(probs).item<int>();
        
        // 4. Report the prediction.
        std::cout << "Predicted class: " << class_names[predicted_idx] 
                  << " (" << predicted_idx << ")\n";
        std::cout << "Confidence: " << probs[0][predicted_idx].item<float>() * 100 << "%\n";
        
        // Print the probability of every class.
        for (size_t i = 0; i < class_names.size(); ++i) {
            std::cout << class_names[i] << ": " 
                      << probs[0][i].item<float>() * 100 << "%\n";
        }
        
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
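With many classes, printing every probability becomes noisy; torch::topk returns just the k most likely predictions. A small sketch (k = 3 is arbitrary, and probs/class_names are the variables from the program above):

// Top-3 predictions, highest probability first.
auto [top_probs, top_idx] = torch::topk(probs, 3, /*dim=*/1);
for (int i = 0; i < 3; ++i) {
    int idx = top_idx[0][i].item<int>();
    std::cout << class_names[idx] << ": "
              << top_probs[0][i].item<float>() * 100 << "%\n";
}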