C++中TensorflowLite模型验证的过程

发布时间:2021-08-27 08:58:36 作者:chen
来源:亿速云 阅读:184

这篇文章主要介绍“C++中TensorflowLite模型验证的过程”,在日常操作中,相信很多人在C++中TensorflowLite模型验证的过程问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”C++中TensorflowLite模型验证的过程”的疑惑有所帮助!接下来,请跟着小编一起来学习吧!

故事是这样的:

有一个手撑检测的tflite模型,需要在开发板上跑起来。手机版本的已成熟,要移植到开发板上。现在要验证tflite模型文件在板子上的运行结果要和手机上一致。

前提:为了多次重复测试,在Android端使用了同一帧数据(从一个录制的mp4中固定取一张图)测试代码如下图

C++中TensorflowLite模型验证的过程

下面是测试过程 

记录下Android版API运行推理前的图片数据文件(经过了规一化处理,所以都是-1~1之间的float数据)

C++中TensorflowLite模型验证的过程

这一步卡在了写float数据到二进制文件中,C++读出来有问题换了个方案,直接存储float字符串

private void saveFile(float[] pfImageData) {
        try {
            File file = new File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS).getAbsolutePath() + "/tfimg");
 
            StringBuilder sb = new StringBuilder();
            for (float val : pfImageData) {
                //保留4位小数,这里可以改为其他值
                sb.append(String.format("%.4f", val));
                sb.append("\r\n");
            }
 
            FileWriter out = new FileWriter(file);  //文件写入流
            out.write(sb.toString());
            out.close();
        } catch (Exception e) {
            e.printStackTrace();
            Log.e("Melon", "存储文件异常," + e.getMessage());
        }
    }

拿着这个文件在板子上输入到Tflite模型中

测试代码,主要是RunInference()和read_file()

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
 
#include "tensorflow/lite/examples/label_image/label_image.h"
 
#include <fcntl.h>     // NOLINT(build/include_order)
#include <getopt.h>    // NOLINT(build/include_order)
#include <sys/time.h>  // NOLINT(build/include_order)
#include <sys/types.h> // NOLINT(build/include_order)
#include <sys/uio.h>   // NOLINT(build/include_order)
#include <unistd.h>    // NOLINT(build/include_order)
 
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
 
#include "absl/memory/memory.h"
#include "tensorflow/lite/examples/label_image/bitmap_helpers.h"
#include "tensorflow/lite/examples/label_image/get_top_n.h"
#include "tensorflow/lite/examples/label_image/log.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/optional_debug_tools.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/delegates/delegate_provider.h"
 
namespace tflite
{
  namespace label_image
  {
 
    double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }
 
    using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
    using ProvidedDelegateList = tflite::tools::ProvidedDelegateList;
 
    class DelegateProviders
    {
    public:
      DelegateProviders() : delegate_list_util_(&params_)
      {
        delegate_list_util_.AddAllDelegateParams();
      }
 
      // Initialize delegate-related parameters from parsing command line arguments,
      // and remove the matching arguments from (*argc, argv). Returns true if all
      // recognized arg values are parsed correctly.
      bool InitFromCmdlineArgs(int *argc, const char **argv)
      {
        std::vector<tflite::Flag> flags;
        // delegate_list_util_.AppendCmdlineFlags(&flags);
 
        const bool parse_result = Flags::Parse(argc, argv, flags);
        if (!parse_result)
        {
          std::string usage = Flags::Usage(argv[0], flags);
          LOG(ERROR) << usage;
        }
        return parse_result;
      }
 
      // According to passed-in settings `s`, this function sets corresponding
      // parameters that are defined by various delegate execution providers. See
      // lite/tools/delegates/README.md for the full list of parameters defined.
      void MergeSettingsIntoParams(const Settings &s)
      {
        // Parse settings related to GPU delegate.
        // Note that GPU delegate does support OpenCL. 'gl_backend' was introduced
        // when the GPU delegate only supports OpenGL. Therefore, we consider
        // setting 'gl_backend' to true means using the GPU delegate.
        if (s.gl_backend)
        {
          if (!params_.HasParam("use_gpu"))
          {
            LOG(WARN) << "GPU deleate execution provider isn't linked or GPU "
                         "delegate isn't supported on the platform!";
          }
          else
          {
            params_.Set<bool>("use_gpu", true);
            // The parameter "gpu_inference_for_sustained_speed" isn't available for
            // iOS devices.
            if (params_.HasParam("gpu_inference_for_sustained_speed"))
            {
              params_.Set<bool>("gpu_inference_for_sustained_speed", true);
            }
            params_.Set<bool>("gpu_precision_loss_allowed", s.allow_fp16);
          }
        }
 
        // Parse settings related to NNAPI delegate.
        if (s.accel)
        {
          if (!params_.HasParam("use_nnapi"))
          {
            LOG(WARN) << "NNAPI deleate execution provider isn't linked or NNAPI "
                         "delegate isn't supported on the platform!";
          }
          else
          {
            params_.Set<bool>("use_nnapi", true);
            params_.Set<bool>("nnapi_allow_fp16", s.allow_fp16);
          }
        }
 
        // Parse settings related to Hexagon delegate.
        if (s.hexagon_delegate)
        {
          if (!params_.HasParam("use_hexagon"))
          {
            LOG(WARN) << "Hexagon deleate execution provider isn't linked or "
                         "Hexagon delegate isn't supported on the platform!";
          }
          else
          {
            params_.Set<bool>("use_hexagon", true);
            params_.Set<bool>("hexagon_profiling", s.profiling);
          }
        }
 
        // Parse settings related to XNNPACK delegate.
        if (s.xnnpack_delegate)
        {
          if (!params_.HasParam("use_xnnpack"))
          {
            LOG(WARN) << "XNNPACK deleate execution provider isn't linked or "
                         "XNNPACK delegate isn't supported on the platform!";
          }
          else
          {
            params_.Set<bool>("use_xnnpack", true);
            params_.Set<bool>("num_threads", s.number_of_threads);
          }
        }
      }
 
      // Create a list of TfLite delegates based on what have been initialized (i.e.
      // 'params_').
      std::vector<ProvidedDelegateList::ProvidedDelegate> CreateAllDelegates()
          const
      {
        return delegate_list_util_.CreateAllRankedDelegates();
      }
 
    private:
      // Contain delegate-related parameters that are initialized from command-line
      // flags.
      tflite::tools::ToolParams params_;
 
      // A helper to create TfLite delegates.
      ProvidedDelegateList delegate_list_util_;
    };
 
    // Takes a file name, and loads a list of labels from it, one per line, and
    // returns a vector of the strings. It pads with empty strings so the length
    // of the result is a multiple of 16, because our model expects that.
 
    // std::vector<uint8_t> read_file(const std::string &input_bmp_name)
    // {
    //   int begin, end;
 
    //   std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary);
    //   if (!file)
    //   {
    //     LOG(FATAL) << "input file " << input_bmp_name << " not found";
    //     exit(-1);
    //   }
 
    //   begin = file.tellg();
    //   file.seekg(0, std::ios::end);
    //   end = file.tellg();
    //   size_t len = end - begin;
 
    //   LOG(INFO) << "len: " << len;
    //   std::vector<uint8_t> img_bytes(len);
 
    //   file.seekg(0, std::ios::beg);
    //   file.read(reinterpret_cast<char *>(img_bytes.data()), len);
 
    //   return img_bytes;
    // }
 
    /**
     * 读取文件
     */
    std::vector<float> read_file(const std::string &input_bmp_name)
    {
      int begin, end;
 
      std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary);
      if (!file)
      {
        LOG(FATAL) << "input file " << input_bmp_name << " not found";
        exit(-1);
      }
 
      begin = file.tellg();
      file.seekg(0, std::ios::end);
      end = file.tellg();
      size_t len = end - begin;
 
      LOG(INFO) << "len: " << len;
      std::vector<float> img_bytes;
 
      file.seekg(0, std::ios::beg);
 
      string strLine = "";
      float temp;
      while (getline(file, strLine))
      {
        temp = atof(strLine.c_str());
        img_bytes.push_back(temp);
      }
 
      LOG(INFO) << "文件读取完成:" << input_bmp_name;
      return img_bytes;
    }
 
    /**
     * 运行推理
     */
    void RunInference(Settings *settings)
    {
      if (!settings->model_name.c_str())
      {
        LOG(ERROR) << "no model file name";
        exit(-1);
      }
 
      std::unique_ptr<tflite::FlatBufferModel> model;
      std::unique_ptr<tflite::Interpreter> interpreter;
      model = tflite::FlatBufferModel::BuildFromFile(settings->model_name.c_str());
      if (!model)
      {
        LOG(ERROR) << "Failed to mmap model " << settings->model_name;
        exit(-1);
      }
      settings->model = model.get();
      LOG(INFO) << "Loaded model " << settings->model_name;
      model->error_reporter();
      LOG(INFO) << "resolved reporter";
 
      tflite::ops::builtin::BuiltinOpResolver resolver;
 
      tflite::InterpreterBuilder(*model, resolver)(&interpreter); //生成interpreter
      if (!interpreter)
      {
        LOG(ERROR) << "Failed to construct interpreter";
        exit(-1);
      }
 
      interpreter->SetAllowFp16PrecisionForFp32(settings->allow_fp16);
 
      if (settings->verbose)
      {
        LOG(INFO) << "tensors size: " << interpreter->tensors_size();
        LOG(INFO) << "nodes size: " << interpreter->nodes_size();
        LOG(INFO) << "inputs: " << interpreter->inputs().size();
        LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0);
 
        int t_size = interpreter->tensors_size();
        for (int i = 0; i < t_size; i++)
        {
          if (interpreter->tensor(i)->name)
            LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
                      << interpreter->tensor(i)->bytes << ", "
                      << interpreter->tensor(i)->type << ", "
                      << interpreter->tensor(i)->params.scale << ", "
                      << interpreter->tensor(i)->params.zero_point;
        }
      }
 
      if (settings->number_of_threads != -1)
      {
        interpreter->SetNumThreads(settings->number_of_threads);
      }
 
      int image_width = 128;
      int image_height = 128;
      int image_channels = 3;
      // std::vector<uint8_t> in = read_bmp(settings->input_bmp_name, &image_width, &image_height, &image_channels, settings);
      std::vector<float> file_bytes = read_file(settings->input_bmp_name);
      for (int i = 0; i < 100; i++)
      {
        //和Android的输入做对比
        LOG(INFO) << i << ": " << file_bytes[i];
      }
 
      /* inputs()[0]得到输入张量数组中的第一个张量,也就是classifier中唯一的那个输入张量;
      input是个整型值,是张量列表中的引索 */
      int input = interpreter->inputs()[0];
      LOG(INFO) << "input: " << input;
 
      const std::vector<int> inputs = interpreter->inputs();
      const std::vector<int> outputs = interpreter->outputs();
 
      LOG(INFO) << "number of inputs: " << inputs.size();
      LOG(INFO) << "input index: " << inputs[0];
      LOG(INFO) << "number of outputs: " << outputs.size();
      LOG(INFO) << "outputs index1: " << outputs[0] << ",outputs index2: " << outputs[1];
 
      if (interpreter->AllocateTensors() != kTfLiteOk)
      { //加载所有tensor
        LOG(ERROR) << "Failed to allocate tensors!";
        exit(-1);
      }
 
      if (settings->verbose)
        PrintInterpreterState(interpreter.get());
 
      // 从输入张量的原数据中得到输入尺寸
      TfLiteIntArray *dims = interpreter->tensor(input)->dims;
      int wanted_height = dims->data[1];
      int wanted_width = dims->data[2];
      int wanted_channels = dims->data[3];
 
      settings->input_type = interpreter->tensor(input)->type;
 
      //typed_tensor返回一个经过固定数据类型转换的tensor指针
      //以input为索引,在TfLiteTensor* content_.tensors这个张量表得到具体的张量
      //返回该张量的data.raw,它指示张量正关联着的内存块
      // resize<float>(interpreter->typed_tensor<float>(input), in.data(),
      //               image_height, image_width, image_channels, wanted_height,
      //               wanted_width, wanted_channels, settings);
 
      //赋值给input tensor
      float *inputP = interpreter->typed_input_tensor<float>(0);
 
      LOG(INFO) << "file_bytes size: " << file_bytes.size();
      for (int i = 0; i < file_bytes.size(); i++)
      {
        inputP[i] = file_bytes[i];
      }
 
      struct timeval start_time, stop_time;
      gettimeofday(&start_time, nullptr);
      for (int i = 0; i < settings->loop_count; i++)
      { //调用模型进行推理
        if (interpreter->Invoke() != kTfLiteOk)
        {
          LOG(ERROR) << "Failed to invoke tflite!";
          exit(-1);
        }
      }
      gettimeofday(&stop_time, nullptr);
      LOG(INFO) << "invoked";
      LOG(INFO) << "average time: "
                << (get_us(stop_time) - get_us(start_time)) /
                       (settings->loop_count * 1000)
                << " ms";
 
      const float threshold = 0.001f;
 
      int output = interpreter->outputs()[1];
      LOG(INFO) << "output: " << output;
      LOG(INFO) << "interpreter->tensors_size: " << interpreter->tensors_size();
 
      TfLiteTensor *tensor = interpreter->tensor(output);
 
      TfLiteIntArray *output_dims = tensor->dims;
      // assume output dims to be something like (1, 1, ... ,size)
      auto output_size = output_dims->data[output_dims->size - 1];
      LOG(INFO) << "索引为" << output << "的输出张量的-"
                << "output_size: " << output_size;
 
      for (int i = 0; i < output_dims->size; i++)
      {
        LOG(INFO) << "元数据有:" << output_dims->data[i];
      }
 
      float *prediction = interpreter->typed_output_tensor<float>(1);
 
      float classificators[1][896][1];
      memcpy(classificators, prediction, 896 * 1 * sizeof(float));
      // float classificators[1][896][18];
      // memcpy(classificators, prediction, 896 * 18 * sizeof(float));
 
      //输出分类结果
      for (float(&r)[896][1] : classificators)
      {
        for (float(&p)[1] : r)
        {
          for (float &q : p)
          {
            std::cout << q << ' ';
          }
          std::cout << std::endl;
        }
        std::cout << std::endl;
      }
    }
 
    int Main(int argc, char **argv)
    {
      DelegateProviders delegate_providers;
      bool parse_result = delegate_providers.InitFromCmdlineArgs(
          &argc, const_cast<const char **>(argv));
      if (!parse_result)
      {
        return EXIT_FAILURE;
      }
 
      Settings s;
 
      int c;
      while (true)
      {
        static struct option long_options[] = {
            {"accelerated", required_argument, nullptr, 'a'},
            {"allow_fp16", required_argument, nullptr, 'f'},
            {"count", required_argument, nullptr, 'c'},
            {"verbose", required_argument, nullptr, 'v'},
            {"image", required_argument, nullptr, 'i'},
            {"labels", required_argument, nullptr, 'l'},
            {"tflite_model", required_argument, nullptr, 'm'},
            {"profiling", required_argument, nullptr, 'p'},
            {"threads", required_argument, nullptr, 't'},
            {"input_mean", required_argument, nullptr, 'b'},
            {"input_std", required_argument, nullptr, 's'},
            {"num_results", required_argument, nullptr, 'r'},
            {"max_profiling_buffer_entries", required_argument, nullptr, 'e'},
            {"warmup_runs", required_argument, nullptr, 'w'},
            {"gl_backend", required_argument, nullptr, 'g'},
            {"hexagon_delegate", required_argument, nullptr, 'j'},
            {"xnnpack_delegate", required_argument, nullptr, 'x'},
            {nullptr, 0, nullptr, 0}};
 
        /* getopt_long stores the option index here. */
        int option_index = 0;
 
        c = getopt_long(argc, argv,
                        "a:b:c:d:e:f:g:i:j:l:m:p:r:s:t:v:w:x:", long_options,
                        &option_index);
 
        /* Detect the end of the options. */
        if (c == -1)
          break;
 
        switch (c)
        {
        case 'a':
          s.accel = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'b':
          s.input_mean = strtod(optarg, nullptr);
          break;
        case 'c':
          s.loop_count =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'e':
          s.max_profiling_buffer_entries =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'f':
          s.allow_fp16 =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'g':
          s.gl_backend =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'i':
          s.input_bmp_name = optarg;
          break;
        case 'j':
          s.hexagon_delegate = optarg;
          break;
        case 'l':
          s.labels_file_name = optarg;
          break;
        case 'm':
          s.model_name = optarg;
          break;
        case 'p':
          s.profiling =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'r':
          s.number_of_results =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 's':
          s.input_std = strtod(optarg, nullptr);
          break;
        case 't':
          s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn)
              optarg, nullptr, 10);
          break;
        case 'v':
          s.verbose =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'w':
          s.number_of_warmup_runs =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'x':
          s.xnnpack_delegate =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'h':
        case '?':
          /* getopt_long already printed an error message. */
          exit(-1);
        default:
          exit(-1);
        }
      }
 
      delegate_providers.MergeSettingsIntoParams(s);
      RunInference(&s);
      return 0;
    }
 
  } // namespace label_image
} // namespace tflite
 
int main(int argc, char **argv)
{
  return tflite::label_image::Main(argc, argv);
}

运行指令 ./ws_app --tflite_model libnewpalm_detection.tflite --image tfimg对比推理前的输入一致

Android端

C++中TensorflowLite模型验证的过程

开发板上

C++中TensorflowLite模型验证的过程

对比推理后的输出一致 Android端

C++中TensorflowLite模型验证的过程

开发板端

C++中TensorflowLite模型验证的过程

到此,关于“C++中TensorflowLite模型验证的过程”的学习就结束了,希望能够解决大家的疑惑。理论与实践的搭配能更好的帮助大家学习,快去试试吧!若想继续学习更多相关知识,请继续关注亿速云网站,小编会继续努力为大家带来更多实用的文章!

推荐阅读:
  1. C++对象模型
  2. php制作验证码的过程

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

c++

上一篇:SpringBoot如何整合OpenApi

下一篇:go语言闭包的知识点整理

相关阅读

您好,登录后才能下订单哦!

密码登录
登录注册
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》