VOC数据集类别统计

jupiter
2021-01-24 / 0 评论 / 897 阅读 / 正在检测是否收录...
温馨提示:
本文最后更新于2022年10月05日,已超过790天没有更新,若内容或图片失效,请留言反馈。

1.数据统计核心代码

import xmltodict
import os
from tqdm import tqdm

xml_dir = input(prompt="请输入xml文件夹地址:") #xml文件路径(Annotations)
if not os.path.exists(xml_dir):
    print("[发生错误]:",xml_dir,"不存在")
else:
    statistical_results = {}

    #进度条功能
    pbar = tqdm(total=len(os.listdir(xml_dir)))
    pbar.set_description("VOC数据集类别统计:") # 设置前缀


    for xml_file in os.listdir(xml_dir):
        # 拼接xml文件的path
        xml_file_path = os.path.join(xml_dir,xml_file)

        # 读取xml文件到字符串
        with open(xml_file_path) as f:
            xml_str = f.read()

        # xml字符串转为字典
        try:
            xml_dic = xmltodict.parse(xml_str)
        except Exception as e:
            print("[发生错误]:",xml_file_path,"解析失败")
        # 获取xml文件中的所有objects
        obj_list = xml_dic["annotation"]["object"]
        if not isinstance(obj_list,list): # xml文件中包含多个object
            obj_list = [obj_list]


        # labels分布统计
        for obj in obj_list:
            if not obj['name'] in statistical_results.keys():
                statistical_results[obj['name']] = 1
            else:
                statistical_results[obj['name']] += 1

        #更新进度条
        pbar.update(1)

    #释放进度条
    pbar.close()

    #输出统计结果
    cls_list = list(statistical_results.keys())
    print("=================================================")
    print("[统计报告]:")
    print("class list:",cls_list)
    print("类别总数:",len(cls_list))
    print("类别分布情况:",statistical_results)
    print("=================================================")
请输入xml文件夹地址:csf
[发生错误]: csf 不存在
请输入xml文件夹地址:C:\Users\itrb\Desktop\AirportApronDatasetLabel\label-finished\105-normal\Annotations
VOC数据集类别统计:: 100%|████████████████████████████████████████████████████████| 1943/1943 [00:01<00:00, 1416.14it/s] ?it/s]
=================================================
[统计报告]:
class list: ['BridgeVehicle', 'Person', 'FollowMe', 'Plane', 'LuggageTruck', 'RefuelingTruck', 'FoodTruck', 'Tractor']
类别总数: 8
类别分布情况: {'BridgeVehicle': 1943, 'Person': 2739, 'FollowMe': 4, 'Plane': 1789, 'LuggageTruck': 83, 'RefuelingTruck': 401, 'FoodTruck': 378, 'Tractor': 748}
=================================================

2.统计结果可视化

2.1 根据统计结果绘制条形图

# 绘制条形图
import matplotlib.pyplot as plt
import numpy as np

# 初始化画布
fig = plt.figure(figsize=(13,5),dpi=200)

# 添加一个子图
ax = plt.subplot(1,1,1)

# 绘制条形图
y_pos = np.arange(len(statistical_results))
ax.barh(y_pos,list(statistical_results.values()))

# 设置y轴的ticklabels
ax.set_yticks(y_pos)
ax.set_yticklabels(list(statistical_results.keys()))

# 设置图片的标题
ax.set_title("The total number of objects = {} in {} images".format(
    np.sum(list(statistical_results.values())),len(os.listdir(xml_dir))
))

plt.show()

image-20210124110737704

2.2 根据统计结果绘制折线图

# 绘制折线图
import matplotlib.pyplot as plt
import numpy as np

# 初始化画布
fig = plt.figure(figsize=(13,5),dpi=200)

# 添加一个子图x
ax = plt.subplot(1,1,1)

# 绘制折线图
x_pos = np.arange(len(statistical_results))
ax.plot(x_pos,list(statistical_results.values()))

# 设置x轴的ticklabels
ax.set_xticks(x_pos)
ax.set_xticklabels(list(statistical_results.keys()),rotation = 45)

# 设置图片的标题
ax.set_title("The total number of objects = {} in {} images".format(
    np.sum(list(statistical_results.values())),len(os.listdir(xml_dir))
))

plt.show()

image-20210124110831876

0

评论 (0)

打卡
取消