定点识别


基于object_detection训练自己的模型

花了不知道多少天、、主要参加一个定点识别的比赛、算是把模型搞定了、虽然结果十分的令人喜感(哈哈、不说了)、、难度有一点大(主要是各种天坑、在这里记录一下)

这是阿里天池的比赛、比赛给出上万张图片主要是服装、要在每个图片上识别出服装每个关键点、并将识别结果的坐标输出、比如左袖口什么的、差不多有24个标签吧、训练集给出的是每个图片的所有关键点的坐标、我的思路是先根据坐标
转化成矩形框(同时对x和y加上自己定义的距离数)、然后通过object_detection确定定位的位置、最后在进行输出(求两个x和两个y的平均来得到中心点)、具体步骤如下:

根据lable切分图片

这个脚本主要是根据lable对图片进行切分、根据lable创建若干个文件夹、切好的图片放到每个对应的文件加下、切分完得到几十万张图片(此刻的内心是奔溃的)、

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import csv
import cv2
import os

path=os.getcwd()
#自己定义框的宽度wide
def drawcnts_and_cut(original_img,x,y,wide):
x1=x-wide
x2=x+wide
y1=y-wide
y2=y+wide
crop_img = original_img[y1:y2, x1:x2]
return crop_img

def start(img_path,save_path,x,y):
original_img= cv2.imread(img_path)
crop_img = drawcnts_and_cut(original_img,int(x),int(y),25)
cv2.imwrite(save_path, crop_img)
def datatranslate(data):
splited=str(data).split()
return splited[0],splited[1]

#自己根据标签数量来改
lable=['class1', 'class2']
csv_reader = csv.reader(open('train\\input.csv', encoding='utf-8'))
num=0
for row in csv_reader:
for i in range(2,26,1):
photo=row[0]
data=row[i]
category=lable[i]
splited = str(row[i]).split("_")
print(photo)
print(num)
if int(splited[0])!=-1:
lib = path + "\\train\\"+photo
savepath=path+"\\output\\"+str(category)+"\\"+str(category)+"+"+str(num)+".jpg"
num+=1
start(lib,savepath,splited[0],splited[1])

将图片转化为对应的xml文件

默认的边框大小为整个图片的d、长度和宽度可以从图片中获取、最终批量的生成xml文件(突然想起比赛的图片切分后生成的30万个文件、还只能分批次的复制、一复制就卡屏、迷醉、、)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os, sys
import glob
from PIL import Image

#根据实际来添加class
list=["class1","class2"]
for a in list:
path=os.getcwd()
#图像存储位置
src_img_dir = path+"\\input2\\"+a
# xml文件存放位置
src_xml_dir = path+"\\input2\\"+a
img_Lists = glob.glob(src_img_dir + '\*.jpg')
img_basenames = []
for item in img_Lists:
img_basenames.append(os.path.basename(item))
img_names = []
for item in img_basenames:
temp1, temp2 = os.path.splitext(item)
img_names.append(temp1)
for img in img_names:
im = Image.open((src_img_dir + '/' + img + '.jpg'))
width, height = im.size
xml_file = open((src_xml_dir + '/' + img + '.xml'), 'w')
xml_file.write('<annotation>\n')
xml_file.write(' <folder>'+a+'</folder>\n')
xml_file.write(' <filename>' + str(img) + '.jpg' + '</filename>\n')
xml_file.write(' <path>' + path +"\\input2\\"+a+"\\"+ str(img) + '.jpg'+ '</path>\n')
xml_file.write(' <size>\n')
xml_file.write(' <width>' + str(width) + '</width>\n')
xml_file.write(' <height>' + str(height) + '</height>\n')
xml_file.write(' <depth>3</depth>\n')
xml_file.write(' </size>\n')
xml_file.write(' <segmented>0</segmented>\n')
xml_file.write(' <object>\n')
xml_file.write(' <name>' + str(img) + '</name>\n')
xml_file.write(' <pose>Unspecified</pose>\n')
xml_file.write(' <truncated>1</truncated>\n')
xml_file.write(' <difficult>0</difficult>\n')
xml_file.write(' <bndbox>\n')
xml_file.write(' <xmin>' + "0" + '</xmin>\n')
xml_file.write(' <ymin>' + "0" + '</ymin>\n')
xml_file.write(' <xmax>' + str(width) + '</xmax>\n')
xml_file.write(' <ymax>' + str(height) + '</ymax>\n')
xml_file.write(' </bndbox>\n')
xml_file.write(' </object>\n')
xml_file.write('</annotation>')

xml转csv文件合并csv文件

要使用如下脚本将xml文件转化为csv文件、最后再把每个目录下的csv文件进行合并(注意删除重复的lable)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#xml转csv文件合并csv文件

import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET
tag=['class1','class2']
num=0

def xml_to_csv(path):
xml_list = []
for xml_file in glob.glob(path + '/*.xml'):
tree = ET.parse(xml_file)
root = tree.getroot()
for member in root.findall('object'):
value = (root.find('filename').text,
int(root.find('size')[0].text),
int(root.find('size')[1].text),
root.find('folder').text,
int(member[4][0].text),
int(member[4][1].text),
int(member[4][2].text),
int(member[4][3].text)
)
xml_list.append(value)
column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
xml_df = pd.DataFrame(xml_list, columns=column_name)
return xml_df


def main():
for a in tag:
image_path = os.path.join(os.getcwd(), 'input2\\'+a)
xml_df = xml_to_csv(image_path)
xml_df.to_csv('data\\'+str(a)+'.csv',index=None)
print('Successfully converted xml to csv.')


main()

通过shell批量合并csv

1
2
3
4
5
6
7
@echo off
E:
cd add
dir
copy *.csv all_keywords.csv
echo 合并成功!'
pause

调用object_detection前的准备

下面是很有参考性的博客和官方的地址
https://blog.csdn.net/honk2012/article/details/79099651
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md
可以翻墙的话推荐下面这篇、这个towardsdatascience还是很不错的
https://towardsdatascience.com/how-to-train-your-own-object-detector-with-tensorflows-object-detector-api-bec72ecfe1d9
基本后面的训练和模型的调用都是在github上的、想普通的个人电脑用ssd的一个mobile就行了、别的根本跑不动、batch设置的越大每次迭代的时间越长、如果太大电脑配置不够的话你就可以重新开机了、、
顺便说说几个坑官方步骤中的 protoc object_detection/protos/*.proto –python_out=. 如果是在window下要下载3.4版本的3.5会有bug
object_detection初始化一定要先执行、不然会给你各种报错、、
官方文档中export PYTHONPATH=$PYTHONPATH:pwd:pwd/slim 如果是windows下执行要用这个命令(查了很久用了很多的坑爹方法、只能说项目对windows不友好)SET PYTHONPATH=%cd%;%cd%\slim 执行目录还是不变
注意这几个坑基本就会很顺畅了、还有一些其他小坑一时想不起来、想到了再加、

文章目录
  1. 1. 根据lable切分图片
  2. 2. 将图片转化为对应的xml文件
  3. 3. xml转csv文件合并csv文件
  4. 4. 调用object_detection前的准备