Health/Assets/Scripts/PoseCheck/EstimateModel/MediaPipeEstimator.cs

189 lines
7.8 KiB
C#

using OpenCVForUnity.CoreModule;
using OpenCVForUnity.ImgprocModule;
using OpenCVForUnity.UnityUtils;
using OpenCVForUnityExample.DnnModel;
using System;
using System.Collections.Generic;
using Unity.Mathematics;
using UnityEngine;
using UnityEngine.Experimental.GlobalIllumination;
using static UnityEngine.Networking.UnityWebRequest;
using OpenCVRange = OpenCVForUnity.CoreModule.Range;
/// <summary>
/// Pose estimator backed by the MediaPipe person-detection + pose models
/// (run through OpenCV DNN). Detects the person on the configured seat side
/// (driver = left half, passenger = right half), runs pose estimation on the
/// largest candidate box, and remaps the 33 MediaPipe body landmarks onto the
/// project's OpenPose-style keypoint layout.
/// </summary>
public class MediaPipeEstimator : Estimator
{
    // Stage 1: finds candidate person bounding boxes in the frame. Null until InitModel succeeds.
    private MediaPipePersonDetector _personDetector;
    // Stage 2: produces landmarks for a single detected person. Null until InitModel succeeds.
    private MediaPipePoseEstimator _poseEstimator;

    /// <summary>
    /// Loads both MediaPipe model files and constructs the two inference stages.
    /// On a missing model file the corresponding field stays null and an error is
    /// logged; <see cref="Check"/> reports that state to the caller each frame.
    /// </summary>
    public override void InitModel()
    {
        var personDetectModelPath = Utils.getFilePath(YogaConfig.MODEL_PATHS[ModelType.MediapipePersonDetect]);
        if (string.IsNullOrEmpty(personDetectModelPath))
        {
            LogPrint.Error(YogaConfig.MODEL_PATHS[ModelType.MediapipePersonDetect] + " is not loaded. Please read “StreamingAssets/OpenCVForUnity/dnn/setup_dnn_module.pdf” to make the necessary setup.");
        }
        else
        {
            // 0.3f / 0.6f / 5000 — detector thresholds and candidate cap.
            // NOTE(review): see MediaPipePersonDetector for the exact parameter
            // meaning (nms/score order) — values kept as originally tuned.
            _personDetector = new MediaPipePersonDetector(personDetectModelPath, 0.3f, 0.6f, 5000);
        }

        var poseModelPath = Utils.getFilePath(YogaConfig.MODEL_PATHS[ModelType.MediapipePose]);
        if (string.IsNullOrEmpty(poseModelPath))
        {
            LogPrint.Error(YogaConfig.MODEL_PATHS[ModelType.MediapipePose] + " is not loaded. Please read “StreamingAssets/OpenCVForUnity/dnn/setup_dnn_module.pdf” to make the necessary setup.");
        }
        else
        {
            // 0.9f — pose confidence threshold, kept as originally tuned.
            _poseEstimator = new MediaPipePoseEstimator(poseModelPath, 0.9f);
        }
    }

    /// <summary>
    /// Verifies both models loaded. When either is missing, draws an error
    /// overlay onto <paramref name="img"/> and returns false.
    /// </summary>
    public override bool Check(ref Mat img)
    {
        if (_personDetector == null || _poseEstimator == null)
        {
            Imgproc.putText(img, "model file is not loaded.", new Point(5, img.rows() - 30), Imgproc.FONT_HERSHEY_SIMPLEX, 0.7, new Scalar(255, 255, 255, 255), 2, Imgproc.LINE_AA, false);
            Imgproc.putText(img, "Please read console message.", new Point(5, img.rows() - 10), Imgproc.FONT_HERSHEY_SIMPLEX, 0.7, new Scalar(255, 255, 255, 255), 2, Imgproc.LINE_AA, false);
            return false;
        }
        return true;
    }

    /// <summary>
    /// Runs person detection, picks the largest box on the configured seat side,
    /// then runs pose estimation on it.
    /// </summary>
    /// <param name="bgrMat">Input frame in BGR order (fed to the DNN models).</param>
    /// <param name="rgbaMat">Display frame; currently unused (visualization is commented out).</param>
    /// <param name="points">Mapped OpenPose-style keypoints, or null when no person was found.</param>
    /// <returns>true when at least one keypoint row was extracted; otherwise false.</returns>
    public override bool Esitmate(Mat bgrMat, Mat rgbaMat, out List<Point> points)
    {
        points = null;

        Mat results = _personDetector.infer(bgrMat);
        if (results.rows() == 0)
        {
            results.Dispose(); // BUGFIX: native Mat was leaked on every early return
            return false;
        }

        // The detector output appears to be in a letterboxed-square coordinate
        // space: shift by half the padding to map back into the original frame.
        // NOTE(review): x_factor/y_factor are fixed at 1 — confirm the detector
        // really emits pixel coordinates of the padded square image.
        float imgSize = Mathf.Max((float)bgrMat.size().width, (float)bgrMat.size().height);
        float x_factor = 1;
        float y_factor = 1;
        float x_shift = (imgSize - (float)bgrMat.size().width) / 2f;
        float y_shift = (imgSize - (float)bgrMat.size().height) / 2f;

        float maxSize = 30; // minimum acceptable box area; smaller candidates are ignored
        int maxRectIndex = -1;
        float[] selectedBox = null; // box of the currently-best candidate
        for (int i = 0; i < results.rows(); ++i)
        {
            float[] results_arr = new float[4];
            results.get(i, 0, results_arr);
            float x1 = Mathf.Floor(results_arr[0] * x_factor - x_shift);
            float y1 = Mathf.Floor(results_arr[1] * y_factor - y_shift);
            float x2 = Mathf.Floor(results_arr[2] * x_factor - x_shift);
            float y2 = Mathf.Floor(results_arr[3] * y_factor - y_shift);

            // Seat filter: keep only boxes whose center lies on the configured
            // side of the frame (driver = left half, passenger = right half).
            bool onLeftHalf = (x1 + x2) / 2 < bgrMat.size().width / 2;
            if (!GlobalData.Instance.IsDriverPosition)
            {
                if (onLeftHalf) // passenger mode: skip driver-side boxes
                    continue;
            }
            else
            {
                if (!onLeftHalf) // driver mode: skip passenger-side boxes
                    continue;
            }

            float size = Mathf.Abs(x2 - x1) * Mathf.Abs(y2 - y1);
            if (maxSize < size)
            {
                maxSize = size;
                maxRectIndex = i;
                selectedBox = new float[4] { x1, y1, x2, y2 };
            }
        }

        if (maxRectIndex < 0) // no suitable box found
        {
            LogPrint.Log("没有找到合适的框");
            results.Dispose();
            return false;
        }

        // BUGFIX: PersonRectResult was previously overwritten inside the loop,
        // so it held the LAST seat-filtered candidate instead of the selected
        // (largest) one. Publish only the box that pose estimation actually uses.
        YogaManager.Instance.PersonRectResult = new List<float[]> { selectedBox };

        Mat personRow = results.row(maxRectIndex);
        List<Mat> poseResults = _poseEstimator.infer(bgrMat, personRow);
        personRow.Dispose();
        results.Dispose();
        //_poseEstimator.visualize(rgbaMat, results[0], false, false);

        points = poseResults.Count > 0 ? GetKeypoints(poseResults[0]) : new List<Point>();
        foreach (Mat m in poseResults)
            m.Dispose(); // BUGFIX: pose output Mats were leaked every frame

        // BUGFIX: this method used to unconditionally return false, so callers
        // could never observe a successful estimation.
        return points.Count > 0;
    }

    /// <summary>No-op: no debug overlay is implemented for this estimator.</summary>
    public override void DebugPrint(ref Mat img, bool isRGB = false)
    {
    }

    /// <summary>
    /// Maps a raw MediaPipe pose result Mat onto the project's OpenPose-style
    /// 19-slot keypoint list (Nose..REar plus Neck/Background placeholders).
    /// Entries may be null when a landmark's presence score is below 0.8, or
    /// (-1,-1) for slots MediaPipe does not provide (Neck, Background).
    /// </summary>
    private List<Point> GetKeypoints(Mat result)
    {
        var retVal = new List<Point>();
        // A full result has at least 317 rows:
        //   rows 0-3     person bbox
        //   rows 4-198   39 screen landmarks x 5 (x, y, z, visibility, presence)
        //   rows 199-315 39 world landmarks x 3
        //   row  316     overall confidence
        // Anything shorter is treated as "no person".
        if (result.empty() || result.rows() < 317)
            return retVal;

        // MediaPipe appends 6 auxiliary landmarks after the 33 body landmarks;
        // drop them so only the 33 standard points remain.
        int auxiliary_points_num = 6;
        int usedLandmarks = 39 - auxiliary_points_num; // 33

        // View rows 4..168 as a 33x5 matrix of screen-space landmarks.
        Mat landmarkRows = result.rowRange(new OpenCVRange(4, 199 - (5 * auxiliary_points_num)));
        Mat results_col4_199_39x5 = landmarkRows.reshape(1, usedLandmarks);

        // Flat [x0, y0, x1, y1, ...] screen coordinates.
        Mat xyCols = results_col4_199_39x5.colRange(new OpenCVRange(0, 2));
        float[] landmarks_screen_xy = new float[usedLandmarks * 2];
        xyCols.get(0, 0, landmarks_screen_xy);

        // Per-landmark presence scores; only keypoints with presence >= 0.8 are kept.
        Mat presenceCol = results_col4_199_39x5.colRange(new OpenCVRange(4, 5));
        float[] landmarks_presence = new float[usedLandmarks];
        presenceCol.get(0, 0, landmarks_presence);

        // Release the sub-Mat headers (they share data with `result` but each
        // header is its own native object).
        presenceCol.Dispose();
        xyCols.Dispose();
        results_col4_199_39x5.Dispose();
        landmarkRows.Dispose();

        // (The xyz screen coordinates, world landmarks, bbox and confidence rows
        // were previously read into arrays and never used — reads removed.)

        // Remap MediaPipe landmark indices onto the OpenPose-style slot order.
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 0));  // Nose
        retVal.Add(new Point(-1, -1));                                        // Neck (not provided by MediaPipe)
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 11)); // LShoulder
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 13)); // LElbow
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 15)); // LWrist
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 12)); // RShoulder
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 14)); // RElbow
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 16)); // RWrist
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 23)); // LHip
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 25)); // LKnee
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 27)); // LAnkle
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 24)); // RHip
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 26)); // RKnee
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 28)); // RAnkle
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 2));  // LEye
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 5));  // REye
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 7));  // LEar
        retVal.Add(GetPointData(landmarks_screen_xy, landmarks_presence, 8));  // REar
        retVal.Add(new Point(-1, -1));                                        // Background placeholder
        return retVal;
    }

    /// <summary>
    /// Returns the screen-space (x, y) of landmark <paramref name="index"/>,
    /// or null when its presence score is below 0.8 (treated as not visible).
    /// Callers must tolerate null entries in the resulting keypoint list.
    /// </summary>
    private Point GetPointData(float[] landmarks, float[] _landmarks_presence, int index)
    {
        if (_landmarks_presence[index] < 0.8f)
            return null;

        int xy = index * 2; // landmarks is a flat [x0, y0, x1, y1, ...] array
        return new Point(landmarks[xy], landmarks[xy + 1]);
    }

    /// <summary>
    /// Releases the native DNN resources held by both stages.
    /// BUGFIX: previously empty, leaking both model wrappers.
    /// </summary>
    public override void DisposeModel()
    {
        _personDetector?.dispose();
        _personDetector = null;
        _poseEstimator?.dispose();
        _poseEstimator = null;
    }
}