近几天笔者深入学习了下机器学习、深度学习,不论是谷歌围棋AIAlphaGo
、还是目前使用的阿里云智能语音合成,都非常吸引人。连续多天的理论学习后,总体而言,绝大多数教程都围绕数学算法展开,而实际上我们的需求与算法之间,对新手而言还是非常不容易去匹配的。比如当下的简单图片分类模型。
0. 简介
turicreate
是苹果开源的一款人工智能工具,特点是简化了机器学习的开发模型,不必成为深度学习专家,就可以实现目标检测、图像分类、或与图像分类相似的分类。缺点是不支持windows,在wsl2里数次安装无果后,我直接写了个Docker,然后拿来就用。
turicreate
的简单性,下面体验过就知道,完全是无脑操作,不需要懂背后的一堆数学模型,是一个比较通用的分类工具。之前学习过一些Pytorch
的基础知识,在图像分类面前,目前还比较困难,后续我会逐步学透。
项目开源地址:apple/turicreate
1. 分类模型基本原理
需求:采集图像资源并进行人工处理,实现基本的分类操作,然后利用机器学习预测未知图像。
以连连看游戏为例,想要设计一种自动消除的机器智能,首先要做的就是机器感知出正确的图片,这样才可以输出二维矩阵,进行连连看消除算法。而图片感知这块,有一种理想状态是:截取卡片内中心点,对比RGB信息,如果一样则归为一类;但实际上由于截图不是非常精确或多点类似图片的影响,这种理想状态只适合比较单调色彩的识别上。
因此就有必要借助机器智能来分类了。
2. 实现步骤
A:数据收集和整理分类
turicreate的图片分类非常好操作。在当前目录创建data文件夹,里面放置按标签名命名的图片文件。每个文件夹对应一种分类,内部为具体的已标识的图片。
B: 训练数据
在Windows系统上,运行turicreate的docker版本,然后使用CMD或Vscode进入容器:
docker run -it --name tc -v C:\\Users\\baiyue\\Desktop\\Game自动化:/app baiyuetribe/turicreate
注意替换自己的windows目录。
然后新建train.py
。
import turicreate as tc
img_folder = 'data'
# 导入数据
data = tc.image_analysis.load_images(img_folder, with_path=True)
# 使用文件名来做标签
data['label'] = data['path'].apply(lambda path: path.split('/')[len(path.split('/')) - 2])
data.save('doraemon-walle.sframe')
# 百分之八十的数据用于训练,百分之二十用于测试
train_data, test_data = data.random_split(0.8, seed=2)
# 开始训练模型
model = tc.image_classifier.create(train_data, target='label')
# 测试模型
predictions = model.predict(test_data)
metrics = model.evaluate(test_data)
# 输出测试结果
print(metrics['accuracy'])
model.save('my_model_file')
上面代码无需做任何处理,然后运行python train.py
进行训练,完成后会在当前目录生成训练结果my_model_file
。
此步骤CPU是满负载的,大概40s后出结果。
总分类数为28种,样本有261个图片,最终精度为1.基本够胜任连连看的图片识别了。
C: 预测数据
训练完毕后,我们随机截截取几个新的图片,命名为1,2,3.png,然后用刚才训练的模型去预测。
当前目录创建predict.py
文件。
import turicreate as tc
loaded_model = tc.load_model('my_model_file')
def getDataset():
#data = tc.image_analysis.load_images('screenshot', with_path=True)
img_list = [str(i)+'.png' for i in range(1,4)]
result = []
for i in img_list:
data = tc.image_analysis.load_images(i, with_path=True) #图片文件名
result_arr = loaded_model.predict(data)[0]
result.append(result_arr)
return result
with open('result.txt','w',encoding='utf-8') as f:
data = getDataset()
print(data)
f.write(str(getDataset()))
运行python predict.py
后,输出预测结果为:['4', '18', '24']
.
对比分类图,可见预测结果是非常准确的。
3. 遗留问题
由于不支持Windows,我采用了自制Docker,使用体验上非常好,但是与windows本地交互不方便,比如本地环境下ps.system('docker exec -it tc python /app/train.py')
命令就无法正常执行,subprocess
也是一样,因此这里就出现了连接阻断。这样的话本地桌面截图就无法调用Docker容器内部的机器模型去识别了。针对这种情况,可以尝试做一个webapi接口,利用url请求传递待检测的图片,容器内部处理后把结果返回过来,就可以解决这种阻断。 事实上这是非常容易实现的方式,毕竟绝大数人工智能,最后都是以API接口的形式对外开放。