add label_list argument (#20)
This commit is contained in:
parent
a47ee2c816
commit
9370343e5e
|
@ -2,4 +2,4 @@
|
|||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
__version__ = '0.1.0'
|
||||
__version__ = '0.1.1'
|
||||
|
|
|
@ -39,6 +39,14 @@ def run():
|
|||
help='The default output format for labelme2yolo is "polygon".'
|
||||
' However, you can choose to output in bbox format by specifying the "bbox" option.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--label_list",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="The ordered label list, for example --label_list cat dog",
|
||||
required=False,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -46,7 +54,8 @@ def run():
|
|||
parser.print_help()
|
||||
return 0
|
||||
|
||||
convertor = Labelme2YOLO(args.json_dir, args.output_format)
|
||||
convertor = Labelme2YOLO(
|
||||
args.json_dir, args.output_format, args.label_list)
|
||||
|
||||
if args.json_name is None:
|
||||
convertor.convert(val_size=args.val_size, test_size=args.test_size)
|
||||
|
|
|
@ -80,7 +80,7 @@ def img_data_to_png_data(img_data):
|
|||
return f_in.read()
|
||||
|
||||
|
||||
def get_label_id_map(json_dir):
|
||||
def get_label_id_map(json_dir: str):
|
||||
label_set = set()
|
||||
|
||||
for file_name in os.listdir(json_dir):
|
||||
|
@ -131,11 +131,17 @@ def save_yolo_image(json_data, json_name, image_dir_path, target_dir):
|
|||
|
||||
class Labelme2YOLO(object):
|
||||
|
||||
def __init__(self, json_dir, output_format):
|
||||
def __init__(self, json_dir, output_format, label_list):
|
||||
self._json_dir = json_dir
|
||||
self._output_format = output_format
|
||||
self._label_list = label_list
|
||||
|
||||
self._label_id_map = get_label_id_map(self._json_dir)
|
||||
if label_list:
|
||||
self._label_id_map = {label: label_id
|
||||
for label_id, label in enumerate(label_list)}
|
||||
else:
|
||||
self._label_id_map = get_label_id_map(self._json_dir)
|
||||
self._label_list = [label for label in self._label_id_map.keys()]
|
||||
|
||||
def _make_train_val_dir(self):
|
||||
self._label_dir_path = os.path.join(self._json_dir,
|
||||
|
@ -275,7 +281,10 @@ class Labelme2YOLO(object):
|
|||
yolo_w = round(float(obj_w / img_w), 6)
|
||||
yolo_h = round(float(obj_h / img_h), 6)
|
||||
|
||||
label_id = self._label_id_map[shape['label']]
|
||||
if shape['label'] in self._label_id_map:
|
||||
label_id = self._label_id_map[shape['label']]
|
||||
else:
|
||||
print('label %s not in %s' % shape['label'], self._label_list)
|
||||
|
||||
return label_id, yolo_center_x, yolo_center_y, yolo_w, yolo_h
|
||||
|
||||
|
@ -290,7 +299,11 @@ class Labelme2YOLO(object):
|
|||
points = extend_point_list(points)
|
||||
if self._output_format == "bbox":
|
||||
points = extend_point_list(points, "bbox")
|
||||
label_id = self._label_id_map[shape['label']]
|
||||
|
||||
if shape['label'] in self._label_id_map:
|
||||
label_id = self._label_id_map[shape['label']]
|
||||
else:
|
||||
print('label %s not in %s' % shape['label'], self._label_list)
|
||||
|
||||
return label_id, points.tolist()
|
||||
|
||||
|
|
Loading…
Reference in New Issue