Update labelme2yolo.py

This commit is contained in:
rooneysh 2021-08-20 09:52:44 +08:00 committed by GitHub
parent 4f78f1a9c2
commit 4f4822786b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 26 additions and 3 deletions

View File

@ -8,14 +8,16 @@ import sys
import argparse
import shutil
import math
from collections import OrderedDict
import json
import cv2
import PIL.Image
from sklearn.model_selection import train_test_split
from labelme import utils
class Labelme2YOLO(object):
def __init__(self, json_dir):
@ -48,7 +50,8 @@ class Labelme2YOLO(object):
for shape in data['shapes']:
label_set.add(shape['label'])
return {label: label_id for label_id, label in enumerate(label_set)}
return OrderedDict([(label, label_id) \
for label_id, label in enumerate(label_set)])
def _train_test_split(self, folders, json_names, val_size):
if len(folders) > 0 and 'train' in folders and 'val' in folders:
@ -101,6 +104,9 @@ class Labelme2YOLO(object):
self._label_dir_path,
target_dir,
yolo_obj_list)
print('Generating dataset.yaml file ...')
self._save_dataset_yaml()
def convert_one(self, json_name):
json_path = os.path.join(self._json_dir, json_name)
@ -179,7 +185,7 @@ class Labelme2YOLO(object):
if yolo_obj_idx + 1 != len(yolo_obj_list) else \
'%s %s %s %s %s' % yolo_obj
f.write(yolo_obj_line)
def _save_yolo_image(self, json_data, json_name, image_dir_path, target_dir):
img_name = json_name.replace('.json', '.png')
img_path = os.path.join(image_dir_path, target_dir,img_name)
@ -189,6 +195,23 @@ class Labelme2YOLO(object):
PIL.Image.fromarray(img).save(img_path)
return img_path
def _save_dataset_yaml(self):
yaml_path = os.path.join(self._json_dir, 'YOLODataset/', 'dataset.yaml')
with open(yaml_path, 'w+') as yaml_file:
yaml_file.write('train: %s\n' % \
os.path.join(self._image_dir_path, 'train/'))
yaml_file.write('val: %s\n\n' % \
os.path.join(self._image_dir_path, 'val/'))
yaml_file.write('nc: %i\n\n' % len(self._label_id_map))
names_str = ''
for label, _ in self._label_id_map.items():
names_str += "'%s', " % label
names_str = names_str.rstrip(', ')
yaml_file.write('names: [%s]' % names_str)
if __name__ == '__main__':
parser = argparse.ArgumentParser()