Update labelme2yolo.py

This commit is contained in:
rooneysh 2021-08-20 09:52:44 +08:00 committed by GitHub
parent 4f78f1a9c2
commit 4f4822786b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 26 additions and 3 deletions

View File

@ -8,14 +8,16 @@ import sys
import argparse import argparse
import shutil import shutil
import math import math
from collections import OrderedDict
import json import json
import cv2 import cv2
import PIL.Image import PIL.Image
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from labelme import utils from labelme import utils
class Labelme2YOLO(object): class Labelme2YOLO(object):
def __init__(self, json_dir): def __init__(self, json_dir):
@ -48,7 +50,8 @@ class Labelme2YOLO(object):
for shape in data['shapes']: for shape in data['shapes']:
label_set.add(shape['label']) label_set.add(shape['label'])
return {label: label_id for label_id, label in enumerate(label_set)} return OrderedDict([(label, label_id) \
for label_id, label in enumerate(label_set)])
def _train_test_split(self, folders, json_names, val_size): def _train_test_split(self, folders, json_names, val_size):
if len(folders) > 0 and 'train' in folders and 'val' in folders: if len(folders) > 0 and 'train' in folders and 'val' in folders:
@ -101,6 +104,9 @@ class Labelme2YOLO(object):
self._label_dir_path, self._label_dir_path,
target_dir, target_dir,
yolo_obj_list) yolo_obj_list)
print('Generating dataset.yaml file ...')
self._save_dataset_yaml()
def convert_one(self, json_name): def convert_one(self, json_name):
json_path = os.path.join(self._json_dir, json_name) json_path = os.path.join(self._json_dir, json_name)
@ -179,7 +185,7 @@ class Labelme2YOLO(object):
if yolo_obj_idx + 1 != len(yolo_obj_list) else \ if yolo_obj_idx + 1 != len(yolo_obj_list) else \
'%s %s %s %s %s' % yolo_obj '%s %s %s %s %s' % yolo_obj
f.write(yolo_obj_line) f.write(yolo_obj_line)
def _save_yolo_image(self, json_data, json_name, image_dir_path, target_dir): def _save_yolo_image(self, json_data, json_name, image_dir_path, target_dir):
img_name = json_name.replace('.json', '.png') img_name = json_name.replace('.json', '.png')
img_path = os.path.join(image_dir_path, target_dir,img_name) img_path = os.path.join(image_dir_path, target_dir,img_name)
@ -189,6 +195,23 @@ class Labelme2YOLO(object):
PIL.Image.fromarray(img).save(img_path) PIL.Image.fromarray(img).save(img_path)
return img_path return img_path
def _save_dataset_yaml(self):
yaml_path = os.path.join(self._json_dir, 'YOLODataset/', 'dataset.yaml')
with open(yaml_path, 'w+') as yaml_file:
yaml_file.write('train: %s\n' % \
os.path.join(self._image_dir_path, 'train/'))
yaml_file.write('val: %s\n\n' % \
os.path.join(self._image_dir_path, 'val/'))
yaml_file.write('nc: %i\n\n' % len(self._label_id_map))
names_str = ''
for label, _ in self._label_id_map.items():
names_str += "'%s', " % label
names_str = names_str.rstrip(', ')
yaml_file.write('names: [%s]' % names_str)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()