1.数据统计核心代码
import xmltodict
import os
from tqdm import tqdm
# Count how many objects of each class appear in a folder of VOC-style
# annotation XML files, then print a summary report.
#
# Names left at module level for later cells:
#   xml_dir             -- the folder path typed by the user
#   statistical_results -- dict mapping class name -> object count
#                          (only defined when xml_dir is a directory)

# The prompt is passed positionally: the built-in input() does not accept
# keyword arguments, so input(prompt=...) raises TypeError there.
xml_dir = input("请输入xml文件夹地址:")  # path of the XML folder (Annotations)
if not os.path.isdir(xml_dir):
    # Also covers a path that exists but is a regular file, which would
    # otherwise crash in os.listdir() below.
    print("[发生错误]:",xml_dir,"不存在")
else:
    statistical_results = {}
    # List the directory once so the progress total matches the loop.
    xml_files = os.listdir(xml_dir)
    # The context manager closes the progress bar even if the loop raises.
    with tqdm(total=len(xml_files)) as pbar:
        pbar.set_description("VOC数据集类别统计:")  # progress bar prefix
        for xml_file in xml_files:
            xml_file_path = os.path.join(xml_dir, xml_file)
            try:
                # Read raw bytes so the XML parser applies the encoding
                # declared in the file instead of the platform default.
                with open(xml_file_path, "rb") as f:
                    xml_dic = xmltodict.parse(f.read())
            except Exception:
                # Unreadable or malformed entry: report it and count
                # nothing for it, rather than reusing the previous file's
                # parse result (or hitting an undefined name).
                print("[发生错误]:",xml_file_path,"解析失败")
                xml_dic = {}

            # Collect the <object> entries; an annotation may have none.
            annotation = xml_dic.get("annotation")
            obj_list = annotation.get("object") if isinstance(annotation, dict) else None
            if obj_list is None:
                obj_list = []
            elif not isinstance(obj_list, list):
                # A single <object> is parsed as one mapping, not a list.
                obj_list = [obj_list]

            # Tally the label of every object that carries a <name>.
            for obj in obj_list:
                name = obj.get("name") if isinstance(obj, dict) else None
                if name is None:
                    continue
                statistical_results[name] = statistical_results.get(name, 0) + 1

            # Advance the bar once per directory entry, parsed or not.
            pbar.update(1)

    # Print the report.
    cls_list = list(statistical_results.keys())
    print("=================================================")
    print("[统计报告]:")
    print("class list:",cls_list)
    print("类别总数:",len(cls_list))
    print("类别分布情况:",statistical_results)
    print("=================================================")
请输入xml文件夹地址:csf
[发生错误]: csf 不存在
请输入xml文件夹地址:C:\Users\itrb\Desktop\AirportApronDatasetLabel\label-finished\105-normal\Annotations
VOC数据集类别统计:: 100%|████████████████████████████████████████████████████████| 1943/1943 [00:01<00:00, 1416.14it/s]
=================================================
[统计报告]:
class list: ['BridgeVehicle', 'Person', 'FollowMe', 'Plane', 'LuggageTruck', 'RefuelingTruck', 'FoodTruck', 'Tractor']
类别总数: 8
类别分布情况: {'BridgeVehicle': 1943, 'Person': 2739, 'FollowMe': 4, 'Plane': 1789, 'LuggageTruck': 83, 'RefuelingTruck': 401, 'FoodTruck': 378, 'Tractor': 748}
=================================================
2.统计结果可视化
2.1 根据统计结果绘制条形图
# Horizontal bar chart of the per-class object counts.
import matplotlib.pyplot as plt
import numpy as np

# Canvas of 13x5 inches at 200 dpi holding a single axes.
fig = plt.figure(figsize=(13,5),dpi=200)
ax = plt.subplot(1,1,1)

# Split the statistics into parallel label / count lists.
class_names = list(statistical_results.keys())
class_counts = list(statistical_results.values())
bar_positions = np.arange(len(class_names))

# One bar per class, with the class name as its y-axis tick label.
ax.barh(bar_positions, class_counts)
ax.set_yticks(bar_positions)
ax.set_yticklabels(class_names)

# The title reports the object total and the number of entries in xml_dir.
total_objects = np.sum(class_counts)
entry_count = len(os.listdir(xml_dir))
ax.set_title(
    "The total number of objects = {} in {} images".format(total_objects, entry_count)
)
plt.show()
2.2 根据统计结果绘制折线图
# Line chart of the per-class object counts.
import matplotlib.pyplot as plt
import numpy as np

# Canvas of 13x5 inches at 200 dpi holding a single axes.
fig = plt.figure(figsize=(13,5),dpi=200)
ax = plt.subplot(1,1,1)

# Split the statistics into parallel label / count lists.
class_names = list(statistical_results.keys())
class_counts = list(statistical_results.values())
tick_positions = np.arange(len(class_names))

# One point per class, joined by a line; class names become x-axis tick
# labels, rotated 45 degrees.
ax.plot(tick_positions, class_counts)
ax.set_xticks(tick_positions)
ax.set_xticklabels(class_names, rotation=45)

# The title reports the object total and the number of entries in xml_dir.
total_objects = np.sum(class_counts)
entry_count = len(os.listdir(xml_dir))
ax.set_title(
    "The total number of objects = {} in {} images".format(total_objects, entry_count)
)
plt.show()
评论 (0)