dl4j如何使用遗传神经网络完成手写数字识别

发布时间:2022-01-05 18:15:41 作者:柒染
来源:亿速云 阅读:151

今天就跟大家聊聊有关dl4j如何使用遗传神经网络完成手写数字识别,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。

实现步骤

这个过程就类似于自然界的优胜劣汰,将神经网络超参看作dna,超参的调整看作dna的突变;当然还可以把拥有不同隐藏层的神经网络看作不同的物种,让竞争过程更加多样化.当然我们这里只讨论一种神经网络的情况

优势: 可以解决很多没有头绪的问题 劣势: 训练效率极低

gitee地址:

https://gitee.com/ichiva/gnn.git

实现步骤 1.进化接口

public interface Evolution {

    /**
     * 遗传
     * @param mDna
     * @param fDna
     * @return
     */
    INDArray inheritance(INDArray mDna,INDArray fDna);

    /**
     * 突变
     * @param dna
     * @param v
     * @param r 突变范围
     * @return
     */
    INDArray mutation(INDArray dna,double v, double r);

    /**
     * 置换
     * @param dna
     * @param v
     * @return
     */
    INDArray substitution(INDArray dna,double v);

    /**
     * 外源
     * @param dna
     * @param v
     * @return
     */
    INDArray other(INDArray dna,double v);

    /**
     * DNA 是否同源
     * @param mDna
     * @param fDna
     * @return
     */
    boolean iSogeny(INDArray mDna, INDArray fDna);
}

一个比较通用的实现

public class MnistEvolution implements Evolution {

    private static final MnistEvolution instance = new MnistEvolution();

    public static MnistEvolution getInstance() {
        return instance;
    }

    @Override
    public INDArray inheritance(INDArray mDna, INDArray fDna) {
        if(mDna == fDna) return mDna;

        long[] mShape = mDna.shape();
        if(!iSogeny(mDna,fDna)){
            throw new RuntimeException("非同源dna");
        }

        INDArray nDna = Nd4j.create(mShape);
        NdIndexIterator it = new NdIndexIterator(mShape);
        while (it.hasNext()){
            long[] next = it.next();

            double val;
            if(Math.random() > 0.5){
                val = fDna.getDouble(next);
            }else {
                val = mDna.getDouble(next);
            }

            nDna.putScalar(next,val);
        }

        return nDna;
    }

    @Override
    public INDArray mutation(INDArray dna, double v, double r) {
        long[] shape = dna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            if(Math.random() < v){
                dna.putScalar(next,dna.getDouble(next) + ((Math.random() - 0.5) * r * 2));
            }else {
                nDna.putScalar(next,dna.getDouble(next));
            }
        }

        return nDna;
    }

    @Override
    public INDArray substitution(INDArray dna, double v) {
        long[] shape = dna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            if(Math.random() > v){
                long[] tag = new long[shape.length];
                for (int i = 0; i < shape.length; i++) {
                    tag[i] = (long) (Math.random() * shape[i]);
                }

                nDna.putScalar(next,dna.getDouble(tag));
            }else {
                nDna.putScalar(next,dna.getDouble(next));
            }
        }

        return nDna;
    }

    @Override
    public INDArray other(INDArray dna, double v) {
        long[] shape = dna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            if(Math.random() > v){
                nDna.putScalar(next,Math.random());
            }else {
                nDna.putScalar(next,dna.getDouble(next));
            }
        }

        return nDna;
    }

    @Override
    public boolean iSogeny(INDArray mDna, INDArray fDna) {
        long[] mShape = mDna.shape();
        long[] fShape = fDna.shape();

        if (mShape.length == fShape.length) {
            for (int i = 0; i < mShape.length; i++) {
                if (mShape[i] != fShape[i]) {
                    return false;
                }
            }

            return true;
        }

        return false;
    }
}

定义智能体配置接口

public interface AgentConfig {
    /**
     * 输入量
     * @return
     */
    int getInput();

    /**
     * 输出量
     * @return
     */
    int getOutput();

    /**
     * 神经网络配置
     * @return
     */
    MultiLayerConfiguration getMultiLayerConfiguration();
}

按手写数字识别进行配置实现

public class MnistConfig implements AgentConfig {
    @Override
    public int getInput() {
        return 28 * 28;
    }

    @Override
    public int getOutput() {
        return 10;
    }

    @Override
    public MultiLayerConfiguration getMultiLayerConfiguration() {
        return new NeuralNetConfiguration.Builder()
                .seed((long) (Math.random() * Long.MAX_VALUE))
                .updater(new Nesterovs(0.006, 0.9))
                .l2(1e-4)
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(getInput())
                        .nOut(1000)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer
                        .nIn(1000)
                        .nOut(getOutput())
                        .activation(Activation.SOFTMAX)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .pretrain(false).backprop(true)
                .build();
    }
}

智能体基类

@Getter
public class Agent {

    private final AgentConfig config;
    private final INDArray dna;
    private final MultiLayerNetwork multiLayerNetwork;

    /**
     * 采用默认方法初始化参数
     * @param config
     */
    public Agent(AgentConfig config){
        this(config,null);
    }


    /**
     *
     * @param config
     * @param dna
     */
    public Agent(AgentConfig config, INDArray dna){
        if(dna == null){
            this.config = config;
            MultiLayerConfiguration conf = config.getMultiLayerConfiguration();
            this.multiLayerNetwork = new MultiLayerNetwork(conf);
            multiLayerNetwork.init();
            this.dna = multiLayerNetwork.params();
        }else {
            this.config = config;
            MultiLayerConfiguration conf = config.getMultiLayerConfiguration();
            this.multiLayerNetwork = new MultiLayerNetwork(conf);
            multiLayerNetwork.init(dna,true);
            this.dna = dna;
        }
    }
}

手写数字智能体实现类

@Getter
@Setter
public class MnistAgent extends Agent {

    private static final AtomicInteger index = new AtomicInteger(0);

    private String name;

    /**
     * 环境适应分数
     */
    private double score;

    /**
     * 验证分数
     */
    private double validScore;

    public MnistAgent(AgentConfig config) {
        this(config,null);
    }

    public MnistAgent(AgentConfig config, INDArray dna) {
        super(config, dna);
        name = "agent-" + index.incrementAndGet();
    }

    public static MnistConfig mnistConfig = new MnistConfig();

    public static MnistAgent newInstance(){
        return new MnistAgent(mnistConfig);
    }

    public static MnistAgent create(INDArray dna){
        return new MnistAgent(mnistConfig,dna);
    }
}

手写数字识别环境构建

@Slf4j
public class MnistEnv {
    /**
     * 环境数据
     */
    private static final ThreadLocal<MnistDataSetIterator> tLocal = ThreadLocal.withInitial(() -> {
        try {
            return new MnistDataSetIterator(128, true, 0);
        } catch (IOException e) {
            throw new RuntimeException("mnist 文件读取失败");
        }
    });

    private static final ThreadLocal<MnistDataSetIterator> testLocal = ThreadLocal.withInitial(() -> {
        try {
            return new MnistDataSetIterator(128, false, 0);
        } catch (IOException e) {
            throw new RuntimeException("mnist 文件读取失败");
        }
    });

    private static final MnistEvolution evolution = MnistEvolution.getInstance();

    /**
     * 环境承载上限
     *
     * 超过上限AI会进行激烈竞争
     */
    private final int max;
    private Double maxScore,minScore;

    /**
     * 环境中的生命体
     *
     * 新生代与历史代共同排序,选出最适应环境的个体
     */
    //2个变量,一个队列保存KEY的顺序,一个MAP保存KEY对应的具体对象的数据  线程安全map
    private final TreeMap<Double,MnistAgent> lives = new TreeMap<>();

    /**
     * 初始化环境
     *
     * 1.向环境中初始化ai
     * 2.将初始化ai进行环境适应性测试,并排序
     * @param max
     */
    public MnistEnv(int max){
        this.max = max;

        for (int i = 0; i < max; i++) {
            MnistAgent agent = MnistAgent.newInstance();
            test(agent);
            synchronized (lives) {
                lives.put(agent.getScore(),agent);
            }
            log.info("初始化智能体 name = {} , score = {}",i,agent.getScore());
        }



        synchronized (lives) {
            minScore = lives.firstKey();
            maxScore = lives.lastKey();
        }
    }

    /**
     * 环境适应性评估
     * @param ai
     */
    public void test(MnistAgent ai){
        MultiLayerNetwork network = ai.getMultiLayerNetwork();
        MnistDataSetIterator dataIterator = tLocal.get();
        Evaluation eval = new Evaluation(ai.getConfig().getOutput());
        try {
            while (dataIterator.hasNext()) {
                DataSet data = dataIterator.next();
                INDArray output = network.output(data.getFeatures(), false);
                eval.eval(data.getLabels(),output);
            }
        }finally {
            dataIterator.reset();
        }
        ai.setScore(eval.accuracy());
    }

    /**
     * 迁移评估
     *
     * @param ai
     */
    public void validation(MnistAgent ai){
        MultiLayerNetwork network = ai.getMultiLayerNetwork();
        MnistDataSetIterator dataIterator = testLocal.get();
        Evaluation eval = new Evaluation(ai.getConfig().getOutput());
        try {
            while (dataIterator.hasNext()) {
                DataSet data = dataIterator.next();
                INDArray output = network.output(data.getFeatures(), false);
                eval.eval(data.getLabels(),output);
            }
        }finally {
            dataIterator.reset();
        }
        ai.setValidScore(eval.accuracy());
    }


    /**
     * 进化
     *
     * 每轮随机创建ai并放入环境中进行优胜劣汰
     * @param n 进化次数
     */
    public void evolution(int n){
        BlockThreadPool blockThreadPool=new BlockThreadPool(2);
        for (int i = 0; i < n; i++) {
            blockThreadPool.execute(() -> contend(newLive()));
        }

//        for (int i = 0; i < n; i++) {
//            contend(newLive());
//        }
    }

    /**
     * 竞争
     * @param ai
     */
    public void contend(MnistAgent ai){
        test(ai);
        quality(ai);
        double score = ai.getScore();
        if(score <= minScore){
            UI.put("无法生存",String.format("name = %s,  score = %s", ai.getName(),ai.getScore()));
            return;
        }

        Map.Entry<Double, MnistAgent> lastEntry;
        synchronized (lives) {
            lives.put(score,ai);
            if (lives.size() > max) {
                MnistAgent lastAI = lives.remove(lives.firstKey());
                UI.put("淘 汰 ",String.format("name = %s, score = %s", lastAI.getName(),lastAI.getScore()));
            }
            lastEntry = lives.lastEntry();
            minScore = lives.firstKey();
        }

        Double lastScore = lastEntry.getKey();
        if(lastScore > maxScore){
            maxScore = lastScore;
            MnistAgent agent = lastEntry.getValue();
            validation(agent);
            UI.put("max验证",String.format("score = %s,validScore = %s",lastScore,agent.getValidScore()));

            try {
                Warehouse.write(agent);
            } catch (IOException ex) {
                log.error("保存对象失败",ex);
            }
        }
    }

    ArrayList<Double> scoreList = new ArrayList<>(100);
    ArrayList<Integer> avgList = new ArrayList<>();
    private void quality(MnistAgent ai) {
        synchronized (scoreList) {
            scoreList.add(ai.getScore());
            if (scoreList.size() >= 100) {
                double avg = scoreList.stream().mapToDouble(e -> e)
                        .average().getAsDouble();
                avgList.add((int) (avg * 1000));
                StringBuffer buffer = new StringBuffer();
                avgList.forEach(e -> buffer.append(e).append('\t'));
                UI.put("平均得分",String.format("aix100 avg = %s",buffer.toString()));
                scoreList.clear();
            }
        }
    }

    /**
     * 随机生成新智能体
     *
     * 完全随机产生母本
     * 随机从比目标相同或更高评分中选择父本
     *
     * 基因进化在1%~10%之间进行,评分越高基于越稳定
     */
    public MnistAgent newLive(){
        double r = Math.random();
        //基因突变率
        double v = r / 11 + 0.01;
        //母本
        MnistAgent mAgent = getMother(r);
        //父本
        MnistAgent fAgent = getFather(r);

        int i = (int) (Math.random() * 3);
        INDArray newDNA = evolution.inheritance(mAgent.getDna(), fAgent.getDna());
        switch (i){
            case 0:
                newDNA = evolution.other(newDNA,v);
                break;
            case 1:
                newDNA = evolution.mutation(newDNA,v,0.1);
                break;
            case 2:
                newDNA = evolution.substitution(newDNA,v);
                break;
        }
        return MnistAgent.create(newDNA);
    }

    /**
     * 父本只选择比母本评分高的样本
     * @param r
     * @return
     */
    private MnistAgent getFather(double r) {
        r += (Math.random() * (1-r));
        return getMother(r);
    }

    private MnistAgent getMother(double r) {
        int index = (int) (r * max);
        return getMnistAgent(index);
    }

    private MnistAgent getMnistAgent(int index) {
        synchronized (lives) {
            Iterator<Map.Entry<Double, MnistAgent>> it = lives.entrySet().iterator();
            for (int i = 0; i < index; i++) {
                it.next();
            }

            return it.next().getValue();
        }
    }
}

主函数

@Slf4j
public class Program {

    public static void main(String[] args) {
        UI.put("开始时间",new Date().toLocaleString());
        MnistEnv env = new MnistEnv(128);
        env.evolution(Integer.MAX_VALUE);
    }
}

运行截图

  dl4j如何使用遗传神经网络完成手写数字识别

看完上述内容,你们对dl4j如何使用遗传神经网络完成手写数字识别有进一步的了解吗?如果还想了解更多知识或者相关内容,请关注亿速云行业资讯频道,感谢大家的支持。

推荐阅读:
  1. 使用Python如何实现简单遗传算法
  2. 如何使用MPI for Python并行化遗传算法

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

上一篇:Redis中服务端请求伪造SSRF的示例分析

下一篇:Gcc常用选项是什么

相关阅读

您好,登录后才能下订单哦!

密码登录
登录注册
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》