Update labelme2yolo.py
This commit is contained in:
parent
4f78f1a9c2
commit
4f4822786b
|
@ -8,14 +8,16 @@ import sys
|
||||||
import argparse
|
import argparse
|
||||||
import shutil
|
import shutil
|
||||||
import math
|
import math
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import cv2
|
import cv2
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
|
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from labelme import utils
|
from labelme import utils
|
||||||
|
|
||||||
|
|
||||||
class Labelme2YOLO(object):
|
class Labelme2YOLO(object):
|
||||||
|
|
||||||
def __init__(self, json_dir):
|
def __init__(self, json_dir):
|
||||||
|
@ -48,7 +50,8 @@ class Labelme2YOLO(object):
|
||||||
for shape in data['shapes']:
|
for shape in data['shapes']:
|
||||||
label_set.add(shape['label'])
|
label_set.add(shape['label'])
|
||||||
|
|
||||||
return {label: label_id for label_id, label in enumerate(label_set)}
|
return OrderedDict([(label, label_id) \
|
||||||
|
for label_id, label in enumerate(label_set)])
|
||||||
|
|
||||||
def _train_test_split(self, folders, json_names, val_size):
|
def _train_test_split(self, folders, json_names, val_size):
|
||||||
if len(folders) > 0 and 'train' in folders and 'val' in folders:
|
if len(folders) > 0 and 'train' in folders and 'val' in folders:
|
||||||
|
@ -101,6 +104,9 @@ class Labelme2YOLO(object):
|
||||||
self._label_dir_path,
|
self._label_dir_path,
|
||||||
target_dir,
|
target_dir,
|
||||||
yolo_obj_list)
|
yolo_obj_list)
|
||||||
|
|
||||||
|
print('Generating dataset.yaml file ...')
|
||||||
|
self._save_dataset_yaml()
|
||||||
|
|
||||||
def convert_one(self, json_name):
|
def convert_one(self, json_name):
|
||||||
json_path = os.path.join(self._json_dir, json_name)
|
json_path = os.path.join(self._json_dir, json_name)
|
||||||
|
@ -179,7 +185,7 @@ class Labelme2YOLO(object):
|
||||||
if yolo_obj_idx + 1 != len(yolo_obj_list) else \
|
if yolo_obj_idx + 1 != len(yolo_obj_list) else \
|
||||||
'%s %s %s %s %s' % yolo_obj
|
'%s %s %s %s %s' % yolo_obj
|
||||||
f.write(yolo_obj_line)
|
f.write(yolo_obj_line)
|
||||||
|
|
||||||
def _save_yolo_image(self, json_data, json_name, image_dir_path, target_dir):
|
def _save_yolo_image(self, json_data, json_name, image_dir_path, target_dir):
|
||||||
img_name = json_name.replace('.json', '.png')
|
img_name = json_name.replace('.json', '.png')
|
||||||
img_path = os.path.join(image_dir_path, target_dir,img_name)
|
img_path = os.path.join(image_dir_path, target_dir,img_name)
|
||||||
|
@ -189,6 +195,23 @@ class Labelme2YOLO(object):
|
||||||
PIL.Image.fromarray(img).save(img_path)
|
PIL.Image.fromarray(img).save(img_path)
|
||||||
|
|
||||||
return img_path
|
return img_path
|
||||||
|
|
||||||
|
def _save_dataset_yaml(self):
|
||||||
|
yaml_path = os.path.join(self._json_dir, 'YOLODataset/', 'dataset.yaml')
|
||||||
|
|
||||||
|
with open(yaml_path, 'w+') as yaml_file:
|
||||||
|
yaml_file.write('train: %s\n' % \
|
||||||
|
os.path.join(self._image_dir_path, 'train/'))
|
||||||
|
yaml_file.write('val: %s\n\n' % \
|
||||||
|
os.path.join(self._image_dir_path, 'val/'))
|
||||||
|
yaml_file.write('nc: %i\n\n' % len(self._label_id_map))
|
||||||
|
|
||||||
|
names_str = ''
|
||||||
|
for label, _ in self._label_id_map.items():
|
||||||
|
names_str += "'%s', " % label
|
||||||
|
names_str = names_str.rstrip(', ')
|
||||||
|
yaml_file.write('names: [%s]' % names_str)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
Loading…
Reference in New Issue