博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Pytorch框架学习(10)——损失函数
阅读量:2238 次
发布时间:2019-05-09

本文共 2732 字,大约阅读时间需要 9 分钟。

文章目录

1. 损失函数概念

  • 损失函数:衡量模型输出与真实标签的差异
    在这里插入图片描述
  • 损失函数(Loss Function):
    L o s s = f ( y ^ , y ) Loss = f(\hat{y}, y) Loss=f(y^,y)
  • 代价函数(Cost Function):
    L o s s = 1 N ∑ i N f ( y i ^ , y i ) Loss = \frac{1}{N}\sum^{N}_{i} f(\hat{y_i}, y_i) Loss=N1iNf(yi^,yi)
  • 目标函数(Objective Function):
    O b j = C o s t + R e g u l a r i z a t i o n ( 正 则 项 ) Obj = Cost + Regularization(正则项) Obj=Cost+Regularization()

在这里插入图片描述

2. 交叉熵损失函数

  • 1.nn.CrossEntropyLoss
    • 功能:nn.LogSoftmax()与nn.NLLLoss()结合,进行交叉熵计算
    • 主要参数:
      • weight:各类别的loss设置权值
      • ignore_index:忽略某个类别
      • reduction:计算模式,可为none/sum/mean
        • none:逐个元素计算
        • sum:所有元素求和,返回标量
        • mean:加权平均,返回标量

3. NLL/BCE/BCEWITHLogits Loss

  • 2.nn.NLLLoss

    • 功能:实现负对数似然函数的负号功能
    • 主要参数:
      • weight:各类别的loss设置权值
      • ignore_index:忽略某个类别
      • reduction:计算模式,可为none/sum/mean
        • none:逐个元素计算
        • sum:所有元素求和,返回标量
        • mean:加权平均,返回标量
  • 3.nn.BCELoss

    • 功能:二分类交叉熵,输入值取值在[0,1]
    • 主要参数:
      • weight:各类别的loss设置权值
      • ignore_index:忽略某个类别
      • reduction:计算模式,可为none/sum/mean
        • none:逐个元素计算
        • sum:所有元素求和,返回标量
        • mean:加权平均,返回标量
  • 4.BCEWITHLogits Loss

    • 功能:结合Sigmoid与二分类交叉熵,网络最后不加sigmoid函数
    • 主要参数:
      • pos_weight:正样本的权值
      • weight:各类别的loss设置权值
      • ignore_index:忽略某个类别
      • reduction:计算模式,可为none/sum/mean
        • none:逐个元素计算
        • sum:所有元素求和,返回标量
        • mean:加权平均,返回标量

数据回归模型中常用的损失函数:

  • 5.nn.L1Loss
    • 功能:计算inputs与target之差的绝对值
    • 公式: l n = ∣ x n − y n ∣ l_n = |x_n - y_n| ln=xnyn
  • 6.nn.MSELoss
    • 功能:计算inputs与target之差的平方
    • 公式: l n = ( x n − y n ) 2 l_n = (x_n - y_n)^2 ln=(xnyn)2

两个损失函数的主要参数为:

  • reduction:计算模式,可为none/sum/mean

    - none:逐个元素计算
    - sum:所有元素求和,返回标量
    - mean:加权平均,返回标量

  • 7.SmoothL1Loss

    • 功能:平滑的L1Loss
    • 参数:
      • reduction:计算模式,可为none/sum/mean
        • none:逐个元素计算
        • sum:所有元素求和,返回标量
        • mean:加权平均,返回标量
          在这里插入图片描述
          在这里插入图片描述
  • 8.PoissonNLLLoss

    • 功能:泊松分布的负对数似然损失函数
    • 主要参数:
      • log_input:输入是否为对数形式,决定计算公式
      • full:计算所有loss,默认为False
      • eps:修正项,避免log(input)为nan
        在这里插入图片描述
  • 9.nn.KLDivLoss

    • 功能:计算KLD(divergence),KL散度,相对熵
    • 注意:需提前将输入计算log-probabilities, 如通过nn.logsoftmax()
    • 主要参数:
      • reduction:计算模式,可为none/sum/mean/batchmean
        • batchmean:batchsize维度求平均值
        • none:逐个元素计算
        • sum:所有元素求和,返回标量
        • mean:加权平均,返回标量
          在这里插入图片描述
  • 10.nn.MarginRankingLoss

    • 功能:计算两个向量之间的相似度,用于排序任务
    • 特别说明:该方法计算两组数据之间的差异,返回一个n*n的loss矩阵
    • 主要参数:
      • margin:边界值,x1与x2之间的差异值
      • reduction:计算模式,可为none/sum/mean
        在这里插入图片描述
  • 11.nn.MultiLabelMarginLoss

    • 功能:多标签边界损失函数
    • 主要参数:
      • reduction:计算模式
    • 示例:四分类任务,样本x输入0类和4类,标签[0,3,-1,-1],不是[1,0,0,1]
      在这里插入图片描述
  • 12.nn.SoftMarginLoss

    • 功能:计算二分类的logistic损失
    • 参数:reduction:计算模式
      在这里插入图片描述
  • 13.nn.MultiLabelSoftMarginLoss

    • 功能:SoftMarginLoss多标签版本
    • 参数:
      • weight:各类别的loss设置权值
      • reduction:计算模式。
        在这里插入图片描述
  • 14.nn.MultiMarginLoss

    • 功能:计算多分类的折页损失
    • 参数:
      • p:可选1或2
      • weight:各类别的loss设置权值
      • margin:边界值
      • reduction:计算模式
        在这里插入图片描述
        在这里插入图片描述
  • 15.nn.TripletMarginLoss

    • 功能:计算三元组损失,人脸验证中常用
    • 主要参数:
      • p:范数的阶,默认为2
      • margin:边界值
      • reduction:计算模式
        在这里插入图片描述
  • 16.nn.HingeEmbeddingLoss

    • 功能:计算两个输入的相似性,常用于非线性embedding和半监督学习
    • 注意:输入x应为两个输入之差的绝对值
    • 主要参数:
      • margin:边界值
      • reduction:计算模式
        在这里插入图片描述
  • 17.nn.CosineEmbeddingLoss

    • 功能:采用余弦相似度计算两个输入的相似性
    • 主要参数:
      • margin:可取值[-1, 1],推荐为[0, 0.5]
      • reduction: 计算模式
        在这里插入图片描述
  • 18.nn.CTCLoss

    • 功能:计算CTC损失,解决时序类数据的分类
    • 主要参数:
      • blank:blank label
      • zero_infinity:无穷大的值或梯度置0
      • reduction:计算模式

转载地址:http://yhobb.baihongyu.com/

你可能感兴趣的文章
java操作cookie 实现两周内自动登录
查看>>
Tomcat 7优化前及优化后的性能对比
查看>>
Java Guava中的函数式编程讲解
查看>>
Eclipse Memory Analyzer 使用技巧
查看>>
tomcat连接超时
查看>>
谈谈编程思想
查看>>
iOS MapKit导航及地理转码辅助类
查看>>
检测iOS的网络可用性并打开网络设置
查看>>
简单封装FMDB操作sqlite的模板
查看>>
iOS开发中Instruments的用法
查看>>
强引用 软引用 弱引用 虚引用
查看>>
数据类型 java转换
查看>>
"NetworkError: 400 Bad Request - http://172.16.47.117:8088/rhip/**/####t/approval?date=976
查看>>
mybatis 根据 数据库表 自动生成 实体
查看>>
C结构体、C++结构体、C++类的区别
查看>>
进程和线程的概念、区别和联系
查看>>
CMake 入门实战
查看>>
绑定CPU逻辑核心的利器——taskset
查看>>
Linux下perf性能测试火焰图只显示函数地址不显示函数名的问题
查看>>
c结构体、c++结构体和c++类的区别以及错误纠正
查看>>