Pytorch深度学习完整GPU图像分类代码

1. CPU与GPU不同

1.输入数据
2.网络模型
3.损失函数
.cuda()

  1. 说明:下面代码中GPU版本中取消下划线的即为CPU版本

2.完成的分类代码(GPU)

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

# from model import *
# 准备数据集
from torch import nn
from torch.utils.data import DataLoader

# 定义训练的设备
~~device = torch.device("cuda")~~ 

train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
# 如果train_data_size=10, 训练数据集的长度为:10
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))


# 利用 DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 创建网络模型
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x
tudui = Tudui()
~~tudui = tudui.to(device)~~ 

# 损失函数
loss_fn = nn.CrossEntropyLoss()
~~loss_fn = loss_fn.to(device)~~ 
# 优化器
# learning_rate = 0.01
# 1e-2=1 x (10)^(-2) = 1 /100 = 0.01
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10

# 添加tensorboard
writer = SummaryWriter("../logs_train")

for i in range(epoch):
    print("-------第 {} 轮训练开始-------".format(i+1))

    # 训练步骤开始
    tudui.train()
    for data in train_dataloader:
        imgs, targets = data
        ~~imgs = imgs.to(device)
        targets = targets.to(device)~~ 
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)

        # 优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # 测试步骤开始
    tudui.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            ~~imgs = imgs.to(device)
            targets = targets.to(device)~~ 
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy

    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
    total_test_step = total_test_step + 1

    torch.save(tudui, "tudui_{}.pth".format(i))
    print("模型已保存")

writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/549294.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【高通平台】如何升级蓝牙的firmware

1. 您可以使用以下命令升级固件 adb push apbtfw11.tlv /bt_firmware/image/ adb push apnv11.bin /bt_firmware/image/ adb shell sync 或者 adb push crbtfw21.tlv /vendor/bt_firmware/image adb push crnv21.bin /vendor/bt_firmware/image adb shell sync 查看代码路径…

项目5-博客系统3+接口完

1.实现显示用户信息 ⽬前⻚⾯的⽤⼾信息部分是写死的. 形如 我们期望这个信息可以随着用户登陆而发生改变. • 如果当前⻚⾯是博客列表⻚, 则显⽰当前登陆⽤⼾的信息. • 如果当前⻚⾯是博客详情⻚, 则显⽰该博客的作者⽤⼾信息. 注意: 当前我们只是实现了显⽰⽤⼾名, 没有…

Linux学习(1)

参考学习韩顺平老师 第一章:LINUX 开山篇-内容介绍 1.2.Linux 使用在那些地方 linux运营工程师主要做: 服务器规划 调试优化 对系统进行日常监控 故障处理 对数据的备份和处理 日志的分析 1.3.Linux 的应用领域 1.个人桌面领域的应用 …

短视频底层逻辑分析

短视频底层逻辑 1.迭代模型_ev 2.Douyin的本质_ev 3.Douyin的审核机制_ev 4.平台趋势_ev 5.定位_ev 6.建立用户期待_ev 7.好内容的定义_ev 8怎么做好内容_ev 9.如何做好选题_ev 10.如何快速模仿_ev 11.账号拆解的底层逻辑_ev 12选人的重要性_ev 13.内容的包装_ev 14.打造大IP的…

easyexcel升级3.3.4失败的经历

原本想通过easyexcel从2.2.6升级到3.3.3解决一部分问题,结果之前的可以用的代码,却无端的出现bug 1 Sheet index (1) is out of range (0…0) 什么都没有改,就出了问题,那么问题肯定出现在easyexcel版本自身.使用模板填充的方式进…

Innodb之redo日志

Innodb引擎执行流程 redo log ​ MySQL中的redo log(重做日志)是实现WAL(预写式日志)技术的关键组件,用于确保事务的持久性和数据库的crash-safe能力。借用《孔乙己》中酒店掌柜使用粉板记录赊账的故事,…

Flask前端页面文本框展示后端变量,路由函数内外两类

一、外&#xff01;路由函数外的前后端数据传输 Flask后端 ↓ 首先导入包&#xff0c;需要使用 后端&#xff1a;flask_socketio来进行路由外的数据传输&#xff0c; from flask_socketio import SocketIO, emit 前端&#xff1a;还有HTML头文件的设置。 <!DOCTYPE …

【云原生数据库:原理与实践】1- 数据库发展历程

1-数据库发展历程 1.1 数据库发展概述 从1960年&#xff1a;Integrated Database System&#xff08;IDS&#xff09;&#xff0c;该系统是一个网状模型&#xff08;Network Model&#xff09;到 IMS&#xff08;Information Management System&#xff09;&#xff0c;使用了…

Rust腐蚀服务器清档多教程

Rust腐蚀服务器清档多教程 大家好我是艾西&#xff0c;一个做服务器租用的网络架构师。上期教了大家怎么搭建服务器以及安装插件等那么随着大家自己架设服或是玩耍的时间肯定会有小伙伴想要去新增开区数量或是把原本的服务器进行一些调整等&#xff0c;那么今天主要聊的就是怎…

【智能算法】鸭群算法(DSA)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献 1.背景 2021年&#xff0c;Zhang等人受到自然界鸭群觅食行为启发&#xff0c;提出了鸭群算法&#xff08;Duck Swarm Algorithm, DSA&#xff09;。 2.算法原理 2.1算法思想 DSA基于自然界鸭群觅食过程&…

[leetcode] max-area-of-island

. - 力扣&#xff08;LeetCode&#xff09; 给你一个大小为 m x n 的二进制矩阵 grid 。 岛屿 是由一些相邻的 1 (代表土地) 构成的组合&#xff0c;这里的「相邻」要求两个 1 必须在 水平或者竖直的四个方向上 相邻。你可以假设 grid 的四个边缘都被 0&#xff08;代表水&…

基于SpringBoot+Vue的共享汽车管理系统(源码+文档+包运行)

一.系统概述 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了共享汽车管理系统的开发全过程。通过分析共享汽车管理系统管理的不足&#xff0c;创建了一个计算机管理共享汽车管理系统的方案。文章介绍了共享汽车管理…

从 MySQL 到 DynamoDB,Canva 如何应对每天新增的 5000 万素材

作为一款设计工具&#xff0c;Canva 吸引人的一个重要特色就是拥有数以亿计的照片和图形资源&#xff0c;支持用户上传个人素材。 Canva 于 2013 年推出&#xff0c;设立了一个包含大量照片和图形的资源库&#xff0c;并允许用户上传自己的素材以用于设计。从发布之日起&#…

Python数学建模学习-PageRank算法

1-基本概念 PageRank算法是由Google创始人Larry Page在斯坦福大学时提出&#xff0c;又称PR&#xff0c;佩奇排名。主要针对网页进行排名&#xff0c;计算网站的重要性&#xff0c;优化搜索引擎的搜索结果。PR值是表示其重要性的因子。 中心思想&#xff1a; 数量假设&#…

JavaScript基础:js介绍、变量、数据类型以及类型转换

目录 介绍 引入方式 内部方式 外部形式 注释和结束符 单行注释 多行注释 结束符 输入和输出 输出 输入 变量 声明 赋值 关键字 变量名命名规则 常量 数据类型 数值类型 字符串类型 布尔类型 undefined 类型转换 隐式转换 显式转换 Number ✨介绍 &a…

typescript中的type关键字和interface关键字区别

Type又叫类型别名&#xff08;type alias&#xff09;,作用是给一个类型起一个新名字&#xff0c;不仅支持interface定义的对象结构&#xff0c;还支持基本类型、联合类型、交叉类型、元组等任何你需要手写的类型。 type num number; // 基本类型 type stringOrNum string |…

47.HarmonyOS鸿蒙系统 App(ArkUI)创建轮播效果

创建轮播效果&#xff0c;共3页切换 Entry Component struct Index {State message: string Hello Worldprivate swiperController: SwiperController new SwiperController()build() {Swiper(this.swiperController) {Text("第一页").width(90%).height(100%).bac…

BLE架构图

PHY层(Physical layer 物理层) PHY层用来指定BLE所用的无线频段(2.4G)&#xff0c;调制解调方式和方法、跳频等。PHY层的性能直接决定整个BLE芯片的功耗、灵敏度以及selectivity等射频指标。 LL层(Link Layer 链路层) 链路层主要是对RF射频控制。链路层定义了协议栈中最为基础的…

C++解决大学课设所有管理系统(增删查改)

C一篇解决大学课设所有**管理系统(增删查改) 文章目录 C一篇解决大学课设所有**管理系统(增删查改)1.引言1.1 使用结果展示 2. 基本原理3. 文件层次结构4.具体实现(通讯录管理系统为例)4.1 通讯录实体类(addressbook.h)4.2 通讯录实现类(addressbook.cpp)4.3 通讯录管理类&…