基于Java实现的一层简单人工神经网络算法示例

本文实例讲述了基于Java实现的一层简单人工神经网络算法。分享给大家供大家参考,具体如下:

先来看看笔者绘制的算法图:

2、数据类

import java.util.Arrays;
public class Data {
  double[] vector;
  int dimention;
  int type;
  public double[] getVector() {
    return vector;
  }
  public void setVector(double[] vector) {
    this.vector = vector;
  }
  public int getDimention() {
    return dimention;
  }
  public void setDimention(int dimention) {
    this.dimention = dimention;
  }
  public int getType() {
    return type;
  }
  public void setType(int type) {
    this.type = type;
  }
  public Data(double[] vector, int dimention, int type) {
    super();
    this.vector = vector;
    this.dimention = dimention;
    this.type = type;
  }
  public Data() {
  }
  @Override
  public String toString() {
    return "Data [vector=" + Arrays.toString(vector) + ", dimention=" + dimention + ", type=" + type + "]";
  }
}

3、简单人工神经网络

package cn.edu.hbut.chenjie;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartFrame;
import org.jfree.chart.JFreeChart;
import org.jfree.data.xy.DefaultXYDataset;
import org.jfree.ui.RefineryUtilities;
public class ANN2 {
  private double eta;//学习率
  private int n_iter;//权重向量w[]训练次数
  private List<Data> exercise;//训练数据集
  private double w0 = 0;//阈值
  private double x0 = 1;//固定值
  private double[] weights;//权重向量,其长度为训练数据维度+1,在本例中数据为2维,故长度为3
  private int testSum = 0;//测试数据总数
  private int error = 0;//错误次数
  DefaultXYDataset xydataset = new DefaultXYDataset();
  /**
   * 向图表中增加同类型的数据
   * @param type 类型
   * @param a 所有数据的第一个分量
   * @param b 所有数据的第二个分量
   */
  public void add(String type,double[] a,double[] b)
  {
    double[][] data = new double[2][a.length];
    for(int i=0;i<a.length;i++)
    {
      data[0][i] = a[i];
      data[1][i] = b[i];
    }
    xydataset.addSeries(type, data);
  }
  /**
   * 画图
   */
  public void draw()
  {
    JFreeChart jfreechart = ChartFactory.createScatterPlot("exercise", "x1", "x2", xydataset);
    ChartFrame frame = new ChartFrame("训练数据", jfreechart);
    frame.pack();
    RefineryUtilities.centerFrameOnScreen(frame);
    frame.setVisible(true);
  }
  public static void main(String[] args)
  {
    ANN2 ann2 = new ANN2(0.001,100);//构造人工神经网络
    List<Data> exercise = new ArrayList<Data>();//构造训练集
    //人工模拟1000条训练数据 ,分界线为x2=x1+0.5
    for(int i=0;i<1000000;i++)
    {
      Random rd = new Random();
      double x1 = rd.nextDouble();//随机产生一个分量
      double x2 = rd.nextDouble();//随机产生另一个分量
      double[] da = {x1,x2};//产生数据向量
      Data d = new Data(da, 2, x2 > x1+0.5 ? 1 : -1);//构造数据
      exercise.add(d);//将训练数据加入训练集
    }
    int sum1 = 0;//记录类型1的训练记录数
    int sum2 = 0;//记录类型-1的训练记录数
    for(int i = 0; i < exercise.size(); i++)
    {
      if(exercise.get(i).getType()==1)
        sum1++;
      else if(exercise.get(i).getType()==-1)
        sum2++;
    }
    double[] x1 = new double[sum1];
    double[] y1 = new double[sum1];
    double[] x2 = new double[sum2];
    double[] y2 = new double[sum2];
    int index1 = 0;
    int index2 = 0;
    for(int i = 0; i < exercise.size(); i++)
    {
      if(exercise.get(i).getType()==1)
      {
        x1[index1] = exercise.get(i).vector[0];
        y1[index1++] = exercise.get(i).vector[1];
      }
      else if(exercise.get(i).getType()==-1)
      {
        x2[index2] = exercise.get(i).vector[0];
        y2[index2++] = exercise.get(i).vector[1];
      }
    }
    ann2.add("1", x1, y1);
    ann2.add("-1", x2, y2);
    ann2.draw();
    ann2.input(exercise);//将训练集输入人工神经网络
    ann2.fit();//训练
    ann2.showWeigths();//显示权重向量
    //人工生成一千条测试数据
    for(int i=0;i<10000;i++)
    {
      Random rd = new Random();
      double x1_ = rd.nextDouble();
      double x2_ = rd.nextDouble();
      double[] da = {x1_,x2_};
      Data test = new Data(da, 2, x2_ > x1_+0.5 ? 1 : -1);
      ann2.predict(test);//测试
    }
    System.out.println("总共测试" + ann2.testSum + "条数据,有" + ann2.error + "条错误,错误率:" + ann2.error * 1.0 /ann2.testSum * 100 + "%");
  }
  /**
   *
   * @param eta 学习率
   * @param n_iter 权重分量学习次数
   */
  public ANN2(double eta, int n_iter) {
    this.eta = eta;
    this.n_iter = n_iter;
  }
  /**
   * 输入训练集到人工神经网络
   * @param exercise
   */
  private void input(List<Data> exercise) {
    this.exercise = exercise;//保存训练集
    weights = new double[exercise.get(0).dimention + 1];//初始化权重向量,其长度为训练数据维度+1
    weights[0] = w0;//权重向量第一个分量为w0
    for(int i = 1; i < weights.length; i++)
      weights[i] = 0;//其余分量初始化为0
  }
  private void fit() {
    for(int i = 0; i < n_iter; i++)//权重分量调整n_iter次
    {
      for(int j = 0; j < exercise.size(); j++)//对于训练集中的每条数据进行训练
      {
        int real_result = exercise.get(j).type;//y
        int calculate_result = CalculateResult(exercise.get(j));//y'
        double delta0 = eta * (real_result - calculate_result);//计算阈值更新
        w0 += delta0;//阈值更新
        weights[0] = w0;//更新w[0]
        for(int k = 0; k < exercise.get(j).getDimention(); k++)//更新权重向量其它分量
        {
          double delta = eta * (real_result - calculate_result) * exercise.get(j).vector[k];
          //Δw=η*(y-y')*X
          weights[k+1] += delta;
          //w=w+Δw
        }
      }
    }
  }
  private int CalculateResult(Data data) {
    double z = w0 * x0;
    for(int i = 0; i < data.dimention; i++)
      z += data.vector[i] * weights[i+1];
    //z=w0x0+w1x1+...+WmXm
    //激活函数
    if(z>=0)
      return 1;
    else
      return -1;
  }
  private void showWeigths()
  {
    for(double w : weights)
      System.out.println(w);
  }
  private void predict(Data data) {
    int type = CalculateResult(data);
    if(type == data.getType())
    {
      //System.out.println("预测正确");
    }
    else
    {
      //System.out.println("预测错误");
      error ++;
    }
    testSum ++;
  }
}

运行结果:

-0.22000000000000017
-0.4416843982815453
0.442444202054685
总共测试10000条数据,有17条错误,错误率:0.16999999999999998%

更多关于java算法相关内容感兴趣的读者可查看本站专题:《Java数据结构与算法教程》、《Java操作DOM节点技巧总结》、《Java文件与目录操作技巧汇总》和《Java缓存操作技巧汇总》

希望本文所述对大家java程序设计有所帮助。

(0)

相关推荐

  • Java语言实现Blowfish加密算法完整代码分享

    前几天网上突然出现流言:某东发生数据泄露12G,最终某东在一篇声明中没有否认,还算是勉强承认了吧,这件事对于一般人有什么影响.应该怎么做已经有一堆人说了,所以就不凑热闹了,咱来点对程序猿来说实际点的,说一个个人认为目前比较安全的加密算法:Blowfish. 上代码之前,先说几点Blowfish加密算法的特点: 1. 对称加密,即加密的密钥和解密的密钥是相同的: 2. 每次加密之后的结果是不同的(这也是老夫比较欣赏的一点): 3. 可逆的,和老夫之前的文章介绍的md5等摘要算法不一样,他是可逆的:

  • Java编程实现A*算法完整代码

    前言 A*搜寻算法俗称A星算法.这是一种在图形平面上,有多个节点的路径,求出最低通过成本的算法.常用于游戏中 通过二维数组构建的一个迷宫,"%"表示墙壁,A为起点,B为终点,"#"代表障碍物,"*"代表算法计算后的路径 本文实例代码结构: % % % % % % % % o o o o o % % o o # o o % % A o # o B % % o o # o o % % o o o o o % % % % % % % % =======

  • Java语言实现快速幂取模算法详解

    快速幂取模算法的引入是从大数的小数取模的朴素算法的局限性所提出的,在朴素的方法中我们计算一个数比如5^1003%31是非常消耗我们的计算资源的,在整个计算过程中最麻烦的就是我们的5^1003这个过程 缺点1:在我们在之后计算指数的过程中,计算的数字不都拿得增大,非常的占用我们的计算资源(主要是时间,还有空间) 缺点2:我们计算的中间过程数字大的恐怖,我们现有的计算机是没有办法记录这么长的数据的,所以说我们必须要想一个更加高效的方法来解决这个问题 当我们计算AB%C的时候,最便捷的方法就是调用Ma

  • Java矩阵连乘问题(动态规划)算法实例分析

    本文实例讲述了Java矩阵连乘问题(动态规划)算法.分享给大家供大家参考,具体如下: 问题描述:给定n个矩阵:A1,A2,...,An,其中Ai与Ai+1是可乘的,i=1,2...,n-1.确定计算矩阵连乘积的计算次序,使得依此次序计算矩阵连乘积需要的数乘次数最少.输入数据为矩阵个数和每个矩阵规模,输出结果为计算矩阵连乘积的计算次序和最少数乘次数. 问题解析:由于矩阵乘法满足结合律,故计算矩阵的连乘积可以有许多不同的计算次序.这种计算次序可以用加括号的方式来确定.若一个矩阵连乘积的计算次序完全确

  • Java基于栈方式解决汉诺塔问题实例【递归与非递归算法】

    本文实例讲述了Java基于栈方式解决汉诺塔问题.分享给大家供大家参考,具体如下: /** * 栈方式非递归汉诺塔 * @author zy * */ public class StackHanoi { /** * @param args */ public static void main(String[] args) { System.out.println("我们测试结果:"); System.out.println("递归方式:"); hanoiNormal(

  • Java基于分治算法实现的棋盘覆盖问题示例

    本文实例讲述了Java基于分治算法实现的棋盘覆盖问题.分享给大家供大家参考,具体如下: 在一个2^k * 2^k个方格组成的棋盘中,有一个方格与其它的不同,若使用以下四种L型骨牌覆盖除这个特殊方格的其它方格,如何覆盖.四个L型骨牌如下图: 棋盘中的特殊方格如图: 实现的基本原理是将2^k * 2^k的棋盘分成四块2^(k - 1) * 2^(k - 1)的子棋盘,特殊方格一定在其中的一个子棋盘中,如果特殊方格在某一个子棋盘中,继续递归处理这个子棋盘,直到这个子棋盘中只有一个方格为止如果特殊方格不

  • Java基于递归和循环两种方式实现未知维度集合的笛卡尔积算法示例

    本文实例讲述了Java基于递归和循环两种方式实现未知维度集合的笛卡尔积.分享给大家供大家参考,具体如下: 什么是笛卡尔积? 在数学中,两个集合X和Y的笛卡儿积(Cartesian product),又称直积,表示为X × Y,第一个对象是X的成员而第二个对象是Y的所有可能有序对的其中一个成员. 假设集合A={a,b},集合B={0,1,2},则两个集合的笛卡尔积为{(a,0),(a,1),(a,2),(b,0),(b,1), (b,2)}. 如何用程序算法实现笛卡尔积? 如果编程前已知集合的数量

  • 70行Java代码实现深度神经网络算法分享

    对于现在流行的深度学习,保持学习精神是必要的--程序员尤其是架构师永远都要对核心技术和关键算法保持关注和敏感,必要时要动手写一写掌握下来,先不用关心什么时候用到--用不用是政治问题,会不会写是技术问题,就像军人不关心打不打的问题,而要关心如何打赢的问题. 程序员如何学习机器学习 对程序员来说,机器学习是有一定门槛的(这个门槛也是其核心竞争力),相信很多人在学习机器学习时都会为满是数学公式的英文论文而头疼,甚至可能知难而退.但实际上机器学习算法落地程序并不难写,下面是70行代码实现的反向多层(BP

  • 基于Java实现的图的广度优先遍历算法

    本文以实例形式讲述了基于Java的图的广度优先遍历算法实现方法,具体方法如下: 用邻接矩阵存储图方法: 1.确定图的顶点个数和边的个数 2.输入顶点信息存储在一维数组vertex中 3.初始化邻接矩阵: 4.依次输入每条边存储在邻接矩阵arc中 输入边依附的两个顶点的序号i,j: 将邻接矩阵的第i行第j列的元素值置为1: 将邻接矩阵的第j行第i列的元素值置为1: 广度优先遍历实现: 1.初始化队列Q 2.访问顶点v:visited[v]=1;顶点v入队Q; 3.while(队列Q非空) v=队列

  • Java实现DFA算法对敏感词、广告词过滤功能示例

    一.前言 开发中经常要处理用户一些文字的提交,所以涉及到了敏感词过滤的功能,参考资料中DFA有穷状态机算法的实现,创建有向图.完成了对敏感词.广告词的过滤,而且效率较好,所以分享一下. 具体实现: 1.匹配大小写过滤  2.匹配全角半角过滤  3.匹配过滤停顿词过滤.  4.敏感词重复词过滤. 例如: 支持如下类型类型过滤检测: fuck 全小写 FuCk 大小写 fuck全角半角 f!!!u&c ###k 停顿词 fffuuuucccckkk 重复词 敏感词过滤的做法有很多,我简单描述我现在理

  • 关于JAVA经典算法40题(超实用版)

    [程序1]题目:古典问题:有一对兔子,从出生后第3个月起每个月都生一对兔子,小兔子长到第四个月后每个月又生一对兔子,假如兔子都不死,问每个月的兔子总数为多少?1.程序分析: 兔子的规律为数列1,1,2,3,5,8,13,21....public class exp2{ public static void main(String args[]){ int i=0; for(i=1;i<=20;i++)System.out.println(f(i));}public static int f(in

随机推荐