add label_list argument (#20)
This commit is contained in:
parent
a47ee2c816
commit
9370343e5e
|
@ -2,4 +2,4 @@
|
||||||
#
|
#
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
__version__ = '0.1.0'
|
__version__ = '0.1.1'
|
||||||
|
|
|
@ -39,6 +39,14 @@ def run():
|
||||||
help='The default output format for labelme2yolo is "polygon".'
|
help='The default output format for labelme2yolo is "polygon".'
|
||||||
' However, you can choose to output in bbox format by specifying the "bbox" option.',
|
' However, you can choose to output in bbox format by specifying the "bbox" option.',
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--label_list",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="The ordered label list, for example --label_list cat dog",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -46,7 +54,8 @@ def run():
|
||||||
parser.print_help()
|
parser.print_help()
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
convertor = Labelme2YOLO(args.json_dir, args.output_format)
|
convertor = Labelme2YOLO(
|
||||||
|
args.json_dir, args.output_format, args.label_list)
|
||||||
|
|
||||||
if args.json_name is None:
|
if args.json_name is None:
|
||||||
convertor.convert(val_size=args.val_size, test_size=args.test_size)
|
convertor.convert(val_size=args.val_size, test_size=args.test_size)
|
||||||
|
|
|
@ -80,7 +80,7 @@ def img_data_to_png_data(img_data):
|
||||||
return f_in.read()
|
return f_in.read()
|
||||||
|
|
||||||
|
|
||||||
def get_label_id_map(json_dir):
|
def get_label_id_map(json_dir: str):
|
||||||
label_set = set()
|
label_set = set()
|
||||||
|
|
||||||
for file_name in os.listdir(json_dir):
|
for file_name in os.listdir(json_dir):
|
||||||
|
@ -131,11 +131,17 @@ def save_yolo_image(json_data, json_name, image_dir_path, target_dir):
|
||||||
|
|
||||||
class Labelme2YOLO(object):
|
class Labelme2YOLO(object):
|
||||||
|
|
||||||
def __init__(self, json_dir, output_format):
|
def __init__(self, json_dir, output_format, label_list):
|
||||||
self._json_dir = json_dir
|
self._json_dir = json_dir
|
||||||
self._output_format = output_format
|
self._output_format = output_format
|
||||||
|
self._label_list = label_list
|
||||||
|
|
||||||
self._label_id_map = get_label_id_map(self._json_dir)
|
if label_list:
|
||||||
|
self._label_id_map = {label: label_id
|
||||||
|
for label_id, label in enumerate(label_list)}
|
||||||
|
else:
|
||||||
|
self._label_id_map = get_label_id_map(self._json_dir)
|
||||||
|
self._label_list = [label for label in self._label_id_map.keys()]
|
||||||
|
|
||||||
def _make_train_val_dir(self):
|
def _make_train_val_dir(self):
|
||||||
self._label_dir_path = os.path.join(self._json_dir,
|
self._label_dir_path = os.path.join(self._json_dir,
|
||||||
|
@ -275,7 +281,10 @@ class Labelme2YOLO(object):
|
||||||
yolo_w = round(float(obj_w / img_w), 6)
|
yolo_w = round(float(obj_w / img_w), 6)
|
||||||
yolo_h = round(float(obj_h / img_h), 6)
|
yolo_h = round(float(obj_h / img_h), 6)
|
||||||
|
|
||||||
label_id = self._label_id_map[shape['label']]
|
if shape['label'] in self._label_id_map:
|
||||||
|
label_id = self._label_id_map[shape['label']]
|
||||||
|
else:
|
||||||
|
print('label %s not in %s' % shape['label'], self._label_list)
|
||||||
|
|
||||||
return label_id, yolo_center_x, yolo_center_y, yolo_w, yolo_h
|
return label_id, yolo_center_x, yolo_center_y, yolo_w, yolo_h
|
||||||
|
|
||||||
|
@ -290,7 +299,11 @@ class Labelme2YOLO(object):
|
||||||
points = extend_point_list(points)
|
points = extend_point_list(points)
|
||||||
if self._output_format == "bbox":
|
if self._output_format == "bbox":
|
||||||
points = extend_point_list(points, "bbox")
|
points = extend_point_list(points, "bbox")
|
||||||
label_id = self._label_id_map[shape['label']]
|
|
||||||
|
if shape['label'] in self._label_id_map:
|
||||||
|
label_id = self._label_id_map[shape['label']]
|
||||||
|
else:
|
||||||
|
print('label %s not in %s' % shape['label'], self._label_list)
|
||||||
|
|
||||||
return label_id, points.tolist()
|
return label_id, points.tolist()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue