【Android】人脸检测MTCNN移植到 Android安卓
本帖最后由 vcvycy 于 2018-6-27 21:03 编辑
今天刚做好了,断断续续搞了一周,终于改好了:MTCNN 已经移植到 Android。
做了一个简单的demo。
虽然是实验室项目需要的,还是放出来,不知道有没有人需要。这个是人脸检测,就是从图中框出人脸,还有一个人脸识别,google的facenet,我也移植到安卓了。有人需要再加进来。主要参考自https://github.com/AITTSMD/MTCNN-Tensorflow
大致流程:
一、Tensorflow 模型固化
将 PNet、RNet、ONet 的网络参数(.npy)固化成 .pb 格式,方便 Java 载入。固化后的文件放在 assets 中,文件名为 mtcnn_freezed_model.pb。
二、引入 Android 版 TensorFlow 库(依赖名为 tensorflow-android,即 TensorFlow Mobile 的 Java 推理接口,并非 TensorFlow Lite)
只需在build.gradle(module)最后添加以下几行语句即可。参考自官网。
// Register the jcenter repository for every project in the build.
allprojects {
repositories {
jcenter()
}
}
// Add the TensorFlow Android artifact; the '+' version selector asks Gradle
// for the newest version it can resolve.
dependencies {
compile 'org.tensorflow:tensorflow-android:+'
}
三、看MTCNN论文+看MTCNN python实现,然后改成java
有很多坑,比如论文很多细节没讲清,比如android版tensorflow lite 资料太少;Bitmap需要沿着对角线翻转再传入神经网络。然后就差不多了。
四、核心代码【主要3个文件,加起来代码不多,大概600行,全贴出来】
(1)MTCNN.JAVA
package com.example.vcvyc.mtcnn_new;
/*
MTCNN For Android
by cjf@xmu 20180625
*/
import android.content.ContentUris;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.graphics.Point;
import android.graphics.Rect;
import android.support.v4.app.NotificationCompat;
import android.util.Log;
import android.widget.ImageView;
import org.tensorflow.Operation;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import java.util.Vector;
import static java.lang.Math.copySign;
import static java.lang.Math.floor;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.lang.Math.scalb;
public class MTCNN {
//参数
private float factor=0.709f;
private float PNetThreshold=0.6f;
private float RNetThreshold=0.7f;
private float ONetThreshold=0.7f;
//MODEL PATH
private static final String MODEL_FILE= "file:///android_asset/mtcnn_freezed_model.pb";
//tensor name
private static final String PNetInName="pnet/input:0";
private static final String[] PNetOutName =new String[]{"pnet/prob1:0","pnet/conv4-2/BiasAdd:0"};
private static final String RNetInName="rnet/input:0";
private static final String[] RNetOutName =new String[]{ "rnet/prob1:0","rnet/conv5-2/conv5-2:0",};
private static final String ONetInName="onet/input:0";
private static final String[] ONetOutName =new String[]{ "onet/prob1:0","onet/conv6-2/conv6-2:0","onet/conv6-3/conv6-3:0"};
//安卓相关
publiclong lastProcessTime; //最后一张图片处理的时间ms
private static final String TAG="MTCNN";
private AssetManager assetManager;
private TensorFlowInferenceInterface inferenceInterface;
MTCNN(AssetManager mgr){
assetManager=mgr;
loadModel();
}
private boolean loadModel() {
//AssetManager
try {
inferenceInterface = new TensorFlowInferenceInterface(assetManager, MODEL_FILE);
Log.d("Facenet","load model success");
}catch(Exception e){
Log.e("Facenet","load model failed"+e);
return false;
}
return true;
}
//读取Bitmap像素值,预处理(-127.5 /128),转化为一维数组返回
private float[] normalizeImage(Bitmap bitmap){
int w=bitmap.getWidth();
int h=bitmap.getHeight();
float[] floatValues=new float;
int[] intValues=new int;
bitmap.getPixels(intValues,0,bitmap.getWidth(),0,0,bitmap.getWidth(),bitmap.getHeight());
float imageMean=127.5f;
float imageStd=128;
for (int i=0;i<intValues.length;i++){
final int val=intValues;
floatValues = (((val >> 16) & 0xFF) - imageMean) / imageStd;
floatValues = (((val >> 8) & 0xFF) - imageMean) / imageStd;
floatValues = ((val & 0xFF) - imageMean) / imageStd;
}
return floatValues;
}
/*
检测人脸,minSize是最小的人脸像素值
*/
private Bitmap bitmapResize(Bitmap bm, float scale) {
int width = bm.getWidth();
int height = bm.getHeight();
// CREATE A MATRIX FOR THE MANIPULATION。matrix指定图片仿射变换参数
Matrix matrix = new Matrix();
// RESIZE THE BIT MAP
matrix.postScale(scale, scale);
Bitmap resizedBitmap = Bitmap.createBitmap(
bm, 0, 0, width, height, matrix, true);
return resizedBitmap;
}
//输入前要翻转,输出也要翻转
privateint PNetForward(Bitmap bitmap,float [][]PNetOutProb,float[][][]PNetOutBias){
int w=bitmap.getWidth();
int h=bitmap.getHeight();
float[] PNetIn=normalizeImage(bitmap);
Utils.flip_diag(PNetIn,h,w,3); //沿着对角线翻转
inferenceInterface.feed(PNetInName,PNetIn,1,w,h,3);
inferenceInterface.run(PNetOutName,false);
int PNetOutSizeW=(int)Math.ceil(w*0.5-5);
int PNetOutSizeH=(int)Math.ceil(h*0.5-5);
float[] PNetOutP=new float;
float[] PNetOutB=new float;
inferenceInterface.fetch(PNetOutName,PNetOutP);
inferenceInterface.fetch(PNetOutName,PNetOutB);
//【写法一】先翻转,后转为2/3维数组
Utils.flip_diag(PNetOutP,PNetOutSizeW,PNetOutSizeH,2);
Utils.flip_diag(PNetOutB,PNetOutSizeW,PNetOutSizeH,4);
Utils.expand(PNetOutB,PNetOutBias);
Utils.expandProb(PNetOutP,PNetOutProb);
/*
*【写法二】这个比较快,快了3ms。意义不大,用上面的方法比较直观
for (int y=0;y<PNetOutSizeH;y++)
for (int x=0;x<PNetOutSizeW;x++){
int idx=PNetOutSizeH*x+y;
PNetOutProb=PNetOutP;
for(int i=0;i<4;i++)
PNetOutBias=PNetOutB;
}
*/
return 0;
}
//Non-Maximum Suppression
//nms,不符合条件的deleted设置为true
private void nms(Vector<Box> boxes,float threshold,String method){
//NMS.两两比对
//int delete_cnt=0;
int cnt=0;
for(int i=0;i<boxes.size();i++) {
Box box = boxes.get(i);
if (!box.deleted) {
//score<0表示当前矩形框被删除
for (int j = i + 1; j < boxes.size(); j++) {
Box box2=boxes.get(j);
if (!box2.deleted) {
int x1 = max(box.box, box2.box);
int y1 = max(box.box, box2.box);
int x2 = min(box.box, box2.box);
int y2 = min(box.box, box2.box);
if (x2 < x1 || y2 < y1) continue;
int areaIoU = (x2 - x1 + 1) * (y2 - y1 + 1);
float iou=0f;
if (method.equals("Union"))
iou = 1.0f*areaIoU / (box.area() + box2.area() - areaIoU);
else if (method.equals("Min")) {
iou = 1.0f * areaIoU / (min(box.area(), box2.area()));
Log.i(TAG,"iou="+iou);
}
if (iou >= threshold) { //删除prob小的那个框
if (box.score>box2.score)
box2.deleted=true;
else
box.deleted=true;
//delete_cnt++;
}
}
}
}
}
//Log.i(TAG,"sum:"+boxes.size()+" delete:"+delete_cnt);
}
private int generateBoxes(float[][] prob,float[][][]bias,float scale,float threshold,Vector<Box> boxes){
int h=prob.length;
int w=prob.length;
//Log.i(TAG,"height:"+prob.length+" width:"+prob.length);
for (int y=0;y<h;y++)
for (int x=0;x<w;x++){
float score=prob;
//only accept prob >threadshold(0.6 here)
if (score>PNetThreshold){
Box box=new Box();
//score
box.score=score;
//box
box.box=Math.round(x*2/scale);
box.box=Math.round(y*2/scale);
box.box=Math.round((x*2+11)/scale);
box.box=Math.round((y*2+11)/scale);
//bbr
for(int i=0;i<4;i++)
box.bbr=bias;
//add
boxes.addElement(box);
}
}
return 0;
}
private void BoundingBoxReggression(Vector<Box> boxes){
for (int i=0;i<boxes.size();i++)
boxes.get(i).calibrate();
}
//Pnet + Bounding Box Regression + Non-Maximum Regression
/* NMS执行完后,才执行Regression
* (1) For each scale , use NMS with threshold=0.5
* (2) For all candidates , use NMS with threshold=0.7
* (3) Calibrate Bounding Box
* 注意:CNN输入图片最上面一行,坐标为。所以Bitmap需要对折后再跑网络;网络输出同理.
*/
private Vector<Box> PNet(Bitmap bitmap,int minSize){
int whMin=min(bitmap.getWidth(),bitmap.getHeight());
float currentFaceSize=minSize;//currentFaceSize=minSize/(factor^k) k=0,1,2... until excced whMin
Vector<Box> totalBoxes=new Vector<Box>();
//【1】Image Paramid and Feed to Pnet
while (currentFaceSize<=whMin){
float scale=12.0f/currentFaceSize;
//(1)Image Resize
Bitmap bm=bitmapResize(bitmap,scale);
int w=bm.getWidth();
int h=bm.getHeight();
//(2)RUN CNN
int PNetOutSizeW=(int)(Math.ceil(w*0.5-5)+0.5);
int PNetOutSizeH=(int)(Math.ceil(h*0.5-5)+0.5);
float[][] PNetOutProb=new float;;
float[][][] PNetOutBias=new float;
PNetForward(bm,PNetOutProb,PNetOutBias);
//(3)数据解析
Vector<Box> curBoxes=new Vector<Box>();
generateBoxes(PNetOutProb,PNetOutBias,scale,PNetThreshold,curBoxes);
//Log.i(TAG,"CNN Output Box number:"+curBoxes.size()+" Scale:"+scale);
//(4)nms 0.5
nms(curBoxes,0.5f,"Union");
//(5)add to totalBoxes
for (int i=0;i<curBoxes.size();i++)
if (!curBoxes.get(i).deleted)
totalBoxes.addElement(curBoxes.get(i));
//Face Size等比递增
currentFaceSize/=factor;
}
//NMS 0.7
nms(totalBoxes,0.7f,"Union");
//BBR
BoundingBoxReggression(totalBoxes);
return Utils.updateBoxes(totalBoxes);
}
//截取box中指定的矩形框(越界要处理),并resize到size*size大小,返回数据存放到data中。
public Bitmap tmp_bm;
private void crop_and_resize(Bitmap bitmap,Box box,int size,float[] data){
//(2)crop and resize
Matrix matrix = new Matrix();
float scale=1.0f*size/box.width();
matrix.postScale(scale, scale);
Bitmap croped=Bitmap.createBitmap(bitmap, box.left(),box.top(),box.width(), box.height(),matrix,true);
//(3)save
int[] pixels_buf=new int;
croped.getPixels(pixels_buf,0,croped.getWidth(),0,0,croped.getWidth(),croped.getHeight());
float imageMean=127.5f;
float imageStd=128;
for (int i=0;i<pixels_buf.length;i++){
final int val=pixels_buf;
data = (((val >> 16) & 0xFF) - imageMean) / imageStd;
data = (((val >> 8) & 0xFF) - imageMean) / imageStd;
data = ((val & 0xFF) - imageMean) / imageStd;
}
}
/*
* RNET跑神经网络,将score和bias写入boxes
*/
private void RNetForward(float[] RNetIn,Vector<Box>boxes){
int num=RNetIn.length/24/24/3;
//feed & run
inferenceInterface.feed(RNetInName,RNetIn,num,24,24,3);
inferenceInterface.run(RNetOutName,false);
//fetch
float[] RNetP=new float;
float[] RNetB=new float;
inferenceInterface.fetch(RNetOutName,RNetP);
inferenceInterface.fetch(RNetOutName,RNetB);
//转换
for (int i=0;i<num;i++) {
boxes.get(i).score = RNetP;
for (int j=0;j<4;j++)
boxes.get(i).bbr=RNetB;
}
}
//Refine Net
private Vector<Box> RNet(Bitmap bitmap,Vector<Box> boxes){
//RNet Input Init
int num=boxes.size();
float[] RNetIn=new float;
float[] curCrop=new float;
int RNetInIdx=0;
for (int i=0;i<num;i++){
crop_and_resize(bitmap,boxes.get(i),24,curCrop);
Utils.flip_diag(curCrop,24,24,3);
//Log.i(TAG,"Pixels values:"+curCrop+" "+curCrop);
for (int j=0;j<curCrop.length;j++) RNetIn= curCrop;
}
//Run RNet
RNetForward(RNetIn,boxes);
//RNetThreshold
for (int i=0;i<num;i++)
if (boxes.get(i).score<PNetThreshold)
boxes.get(i).deleted=true;
//Nms
nms(boxes,0.7f,"Union");
BoundingBoxReggression(boxes);
return Utils.updateBoxes(boxes);
}
/*
* ONet跑神经网络,将score和bias写入boxes
*/
private void ONetForward(float[] ONetIn,Vector<Box>boxes){
int num=ONetIn.length/48/48/3;
//feed & run
inferenceInterface.feed(ONetInName,ONetIn,num,48,48,3);
inferenceInterface.run(ONetOutName,false);
//fetch
float[] ONetP=new float; //prob
float[] ONetB=new float; //bias
float[] ONetL=new float; //landmark
inferenceInterface.fetch(ONetOutName,ONetP);
inferenceInterface.fetch(ONetOutName,ONetB);
inferenceInterface.fetch(ONetOutName,ONetL);
//转换
for (int i=0;i<num;i++) {
//prob
boxes.get(i).score = ONetP;
//bias
for (int j=0;j<4;j++)
boxes.get(i).bbr=ONetB;
//landmark
for (int j=0;j<5;j++) {
int x=boxes.get(i).left()+(int) (ONetL*boxes.get(i).width());
int y= boxes.get(i).top()+(int) (ONetL*boxes.get(i).height());
boxes.get(i).landmark = new Point(x,y);
//Log.i(TAG," landmarkd "+x+ ""+y);
}
}
}
//ONet
private Vector<Box> ONet(Bitmap bitmap,Vector<Box> boxes){
//ONet Input Init
int num=boxes.size();
float[] ONetIn=new float;
float[] curCrop=new float;
int ONetInIdx=0;
for (int i=0;i<num;i++){
crop_and_resize(bitmap,boxes.get(i),48,curCrop);
Utils.flip_diag(curCrop,48,48,3);
for (int j=0;j<curCrop.length;j++) ONetIn= curCrop;
}
//Run ONet
ONetForward(ONetIn,boxes);
//ONetThreshold
for (int i=0;i<num;i++)
if (boxes.get(i).score<ONetThreshold)
boxes.get(i).deleted=true;
BoundingBoxReggression(boxes);
//Nms
nms(boxes,0.7f,"Min");
return Utils.updateBoxes(boxes);
}
private void square_limit(Vector<Box>boxes,int w,int h){
//square
for (int i=0;i<boxes.size();i++) {
boxes.get(i).toSquareShape();
boxes.get(i).limit_square(w,h);
}
}
/*
* 参数:
* bitmap:要处理的图片
* minFaceSize:最小的人脸像素值.(此值越大,检测越快)
* 返回:
* 人脸框
*/
public Vector<Box> detectFaces(Bitmap bitmap,int minFaceSize) {
long t_start = System.currentTimeMillis();
//【1】PNet generate candidate boxes
Vector<Box> boxes=PNet(bitmap,minFaceSize);
square_limit(boxes,bitmap.getWidth(),bitmap.getHeight());
//【2】RNet
boxes=RNet(bitmap,boxes);
square_limit(boxes,bitmap.getWidth(),bitmap.getHeight());
//【3】ONet
boxes=ONet(bitmap,boxes);
//return
Log.i(TAG,"Mtcnn Detection Time:"+(System.currentTimeMillis()-t_start));
lastProcessTime=(System.currentTimeMillis()-t_start);
returnboxes;
}
}
(2)Utils.Java
package com.example.vcvyc.mtcnn_new;
/*
MTCNN For Android
by cjf@xmu 20180625
*/
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Point;
import android.graphics.Rect;
import android.util.Log;
import android.widget.ImageView;
import java.util.Vector;
public class Utils {
//复制图片,并设置isMutable=true
public static Bitmap copyBitmap(Bitmap bitmap){
return bitmap.copy(bitmap.getConfig(),true);
}
//在bitmap中画矩形
public static void drawRect(Bitmap bitmap,Rect rect){
try {
Canvas canvas = new Canvas(bitmap);
Paint paint = new Paint();
int r=255;//(int)(Math.random()*255);
int g=0;//(int)(Math.random()*255);
int b=0;//(int)(Math.random()*255);
paint.setColor(Color.rgb(r, g, b));
paint.setStrokeWidth(1+bitmap.getWidth()/500 );
paint.setStyle(Paint.Style.STROKE);
canvas.drawRect(rect, paint);
}catch (Exception e){
Log.i("Utils"," error"+e);
}
}
//在图中画点
public static void drawPoints(Bitmap bitmap, Point[] landmark){
for (int i=0;i<landmark.length;i++){
int x=landmark.x;
int y=landmark.y;
//Log.i("Utils"," landmarkd "+x+ ""+y);
drawRect(bitmap,new Rect(x-1,y-1,x+1,y+1));
}
}
//Flip alone diagonal
//对角线翻转。data大小原先为h*w*stride,翻转后变成w*h*stride
public static void flip_diag(float[]data,int h,int w,int stride){
float[] tmp=new float;
for (int i=0;i<w*h*stride;i++) tmp=data;
for (int y=0;y<h;y++)
for (int x=0;x<w;x++){
for (int z=0;z<stride;z++)
data[(x*h+y)*stride+z]=tmp[(y*w+x)*stride+z];
}
}
//src转为二维存放到dst中
public static void expand(float[] src,float[][]dst){
int idx=0;
for (int y=0;y<dst.length;y++)
for (int x=0;x<dst.length;x++)
dst=src;
}
//src转为三维存放到dst中
public static void expand(float[] src,float[][][] dst){
int idx=0;
for (int y=0;y<dst.length;y++)
for (int x=0;x<dst.length;x++)
for (int c=0;c<dst.length;c++)
dst=src;
}
//dst=src[:,:,1]
public static void expandProb(float[] src,float[][]dst){
int idx=0;
for (int y=0;y<dst.length;y++)
for (int x=0;x<dst.length;x++)
dst=src;
}
//box转化为rect
public static Rect[] boxes2rects(Vector<Box> boxes){
int cnt=0;
for (int i=0;i<boxes.size();i++) if (!boxes.get(i).deleted) cnt++;
Rect[] r=new Rect;
int idx=0;
for (int i=0;i<boxes.size();i++)
if (!boxes.get(i).deleted)
r=boxes.get(i).transform2Rect();
return r;
}
//删除做了delete标记的box
public static Vector<Box> updateBoxes(Vector<Box> boxes){
Vector<Box> b=new Vector<Box>();
for (int i=0;i<boxes.size();i++)
if (!boxes.get(i).deleted)
b.addElement(boxes.get(i));
return b;
}
//
static public void showPixel(int v){
Log.i("MainActivity","Pixel:R"+((v>>16)&0xff)+"G:"+((v>>8)&0xff)+ " B:"+(v&0xff));
}
}
(3)Box.java 【保存人脸框+人脸关键点(眼睛鼻子嘴巴)】
package com.example.vcvyc.mtcnn_new;
/*
MTCNN For Android
by cjf@xmu 20180625
*/
import android.graphics.Point;
import android.graphics.Rect;
import android.util.Log;
import static java.lang.Math.max;
import static java.lang.Math.min;
public class Box {
publicint[] box; //left:box,top:box,right:box,bottom:box
publicfloat score; //probability
publicfloat[] bbr; //bounding box regression
publicboolean deleted;
publicPoint[] landmark; //facial landmark.只有ONet输出Landmark
Box(){
box=new int;
bbr=new float;
deleted=false;
landmark=new Point;
}
public int left(){return box;}
public int right(){return box;}
public int top(){return box;}
public int bottom(){return box;}
public int width(){return box-box+1;}
public int height(){return box-box+1;}
//转为rect
public Rect transform2Rect(){
Rect rect=new Rect();
rect.left=Math.round(box);
rect.top=Math.round(box);
rect.right=Math.round(box);
rect.bottom=Math.round(box);
returnrect;
}
//面积
publicint area(){
return width()*height();
}
//Bounding Box Regression
public void calibrate(){
int w=box-box+1;
int h=box-box+1;
box=(int)(box+w*bbr);
box=(int)(box+h*bbr);
box=(int)(box+w*bbr);
box=(int)(box+h*bbr);
for (int i=0;i<4;i++) bbr=0.0f;
}
//当前box转为正方形
public void toSquareShape(){
int w=width();
int h=height();
if (w>h){
box-=(w-h)/2;
box+=(w-h+1)/2;
}else{
box-=(h-w)/2;
box+=(h-w+1)/2;
}
}
//防止边界溢出,并维持square大小
public void limit_square(int w,int h){
if (box<0 || box<0){
int len=max(-box,-box);
box+=len;
box+=len;
}
if (box>=w || box>=h){
int len=max(box-w+1,box-h+1);
box-=len;
box-=len;
}
}
public void limit_square2(int w,int h){
if (width() > w) box-=width()-w;
if (height()> h) box-=height()-h;
if (box<0){
int sz=-box;
box+=sz;
box+=sz;
}
if (box<0){
int sz=-box;
box+=sz;
box+=sz;
}
if (box>=w){
int sz=box-w+1;
box-=sz;
box-=sz;
}
if (box>=h){
int sz=box-h+1;
box-=sz;
box-=sz;
}
}
}
最终项目:https://github.com/vcvycy/MTCNN4Android
效果图:
https://img-blog.csdn.net/20180626233548442?watermark/2/text/aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3ZjdnljeQ==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70
https://img-blog.csdn.net/20180626233616728?watermark/2/text/aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3ZjdnljeQ==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70
注意:这个项目只包含人脸检测,也就是找出人脸框,不包含人脸识别(也就是无法识别当前人脸框里的人是谁)
人脸识别我在另一个帖子里发了,感兴趣的可以看看,链接:https://www.52pojie.cn/thread-758292-1-1.html
lovenuanxin 发表于 2018-6-27 07:57
楼主可以试试腾讯的ncnn,c++源码开源,jni层可控;其实我现在缺计算特征值判断是否是同一个人,和实时处理 ...
我喜欢用google的东西~.
判断同一个人,用facenet,我也移植过去。但是想在安卓实时处理不好做,facenet效果好,但是慢。
你可以试试其他小一点的神经网络。
https://github.com/vcvycy/Android_Facenet 这是我的代码。
湖北吴彦祖 发表于 2018-6-27 15:01
人脸解锁么??nice啊 。。。。扶你上去
这个只包含人脸检测。也就是从图中框出所有的人脸,然后定位眼睛/鼻子/嘴巴位置。
人脸识别才是识别图中的人是谁。
人脸识别我也实现了,不过在另一个项目中:
https://github.com/vcvycy/Android_Facenet
感谢分享,向大佬学习学习!!!
很不错的代码!看起来很有游戏,有时间玩玩。
向大佬致敬!
感谢楼主的热心指导!!!!!!!!!
感谢楼主分享
woc!!!这个是好东西啊
楼主可以试试腾讯的ncnn,c++源码开源,jni层可控;其实我现在缺计算特征值判断是否是同一个人,和实时处理,楼主有何建议么
谢谢分享,学习了