diff --git a/src/labelme2yolo/__about__.py b/src/labelme2yolo/__about__.py index b6c45fc..b6d4217 100644 --- a/src/labelme2yolo/__about__.py +++ b/src/labelme2yolo/__about__.py @@ -2,4 +2,4 @@ # # SPDX-License-Identifier: MIT -__version__ = '0.1.0' +__version__ = '0.1.1' diff --git a/src/labelme2yolo/cli/__init__.py b/src/labelme2yolo/cli/__init__.py index 85a9d9b..5892262 100644 --- a/src/labelme2yolo/cli/__init__.py +++ b/src/labelme2yolo/cli/__init__.py @@ -39,6 +39,14 @@ def run(): help='The default output format for labelme2yolo is "polygon".' ' However, you can choose to output in bbox format by specifying the "bbox" option.', ) + parser.add_argument( + "--label_list", + type=str, + nargs="+", + default=None, + help="The ordered label list, for example --label_list cat dog", + required=False, + ) args = parser.parse_args() @@ -46,7 +54,8 @@ def run(): parser.print_help() return 0 - convertor = Labelme2YOLO(args.json_dir, args.output_format) + convertor = Labelme2YOLO( + args.json_dir, args.output_format, args.label_list) if args.json_name is None: convertor.convert(val_size=args.val_size, test_size=args.test_size) diff --git a/src/labelme2yolo/l2y.py b/src/labelme2yolo/l2y.py index 7dfb17a..71631c4 100644 --- a/src/labelme2yolo/l2y.py +++ b/src/labelme2yolo/l2y.py @@ -80,7 +80,7 @@ def img_data_to_png_data(img_data): return f_in.read() -def get_label_id_map(json_dir): +def get_label_id_map(json_dir: str): label_set = set() for file_name in os.listdir(json_dir): @@ -131,11 +131,17 @@ def save_yolo_image(json_data, json_name, image_dir_path, target_dir): class Labelme2YOLO(object): - def __init__(self, json_dir, output_format): + def __init__(self, json_dir, output_format, label_list): self._json_dir = json_dir self._output_format = output_format + self._label_list = label_list - self._label_id_map = get_label_id_map(self._json_dir) + if label_list: + self._label_id_map = {label: label_id + for label_id, label in enumerate(label_list)} + else: + self._label_id_map = get_label_id_map(self._json_dir) + self._label_list = [label for label in self._label_id_map.keys()] def _make_train_val_dir(self): self._label_dir_path = os.path.join(self._json_dir, @@ -275,7 +281,10 @@ class Labelme2YOLO(object): yolo_w = round(float(obj_w / img_w), 6) yolo_h = round(float(obj_h / img_h), 6) - label_id = self._label_id_map[shape['label']] + if shape['label'] in self._label_id_map: + label_id = self._label_id_map[shape['label']] + else: + print('label %s not in %s' % shape['label'], self._label_list) return label_id, yolo_center_x, yolo_center_y, yolo_w, yolo_h @@ -290,7 +299,11 @@ class Labelme2YOLO(object): points = extend_point_list(points) if self._output_format == "bbox": points = extend_point_list(points, "bbox") - label_id = self._label_id_map[shape['label']] + + if shape['label'] in self._label_id_map: + label_id = self._label_id_map[shape['label']] + else: + print('label %s not in %s' % shape['label'], self._label_list) return label_id, points.tolist()