add label_list argument (#20)

This commit is contained in:
Wang Xin 2023-05-04 14:24:16 +08:00 committed by GitHub
parent a47ee2c816
commit 9370343e5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 7 deletions

View File

@ -2,4 +2,4 @@
#
# SPDX-License-Identifier: MIT
__version__ = '0.1.0'
__version__ = '0.1.1'

View File

@ -39,6 +39,14 @@ def run():
help='The default output format for labelme2yolo is "polygon".'
' However, you can choose to output in bbox format by specifying the "bbox" option.',
)
parser.add_argument(
"--label_list",
type=str,
nargs="+",
default=None,
help="The ordered label list, for example --label_list cat dog",
required=False,
)
args = parser.parse_args()
@ -46,7 +54,8 @@ def run():
parser.print_help()
return 0
convertor = Labelme2YOLO(args.json_dir, args.output_format)
convertor = Labelme2YOLO(
args.json_dir, args.output_format, args.label_list)
if args.json_name is None:
convertor.convert(val_size=args.val_size, test_size=args.test_size)

View File

@ -80,7 +80,7 @@ def img_data_to_png_data(img_data):
return f_in.read()
def get_label_id_map(json_dir):
def get_label_id_map(json_dir: str):
label_set = set()
for file_name in os.listdir(json_dir):
@ -131,11 +131,17 @@ def save_yolo_image(json_data, json_name, image_dir_path, target_dir):
class Labelme2YOLO(object):
def __init__(self, json_dir, output_format):
def __init__(self, json_dir, output_format, label_list):
self._json_dir = json_dir
self._output_format = output_format
self._label_list = label_list
self._label_id_map = get_label_id_map(self._json_dir)
if label_list:
self._label_id_map = {label: label_id
for label_id, label in enumerate(label_list)}
else:
self._label_id_map = get_label_id_map(self._json_dir)
self._label_list = [label for label in self._label_id_map.keys()]
def _make_train_val_dir(self):
self._label_dir_path = os.path.join(self._json_dir,
@ -275,7 +281,10 @@ class Labelme2YOLO(object):
yolo_w = round(float(obj_w / img_w), 6)
yolo_h = round(float(obj_h / img_h), 6)
label_id = self._label_id_map[shape['label']]
if shape['label'] in self._label_id_map:
label_id = self._label_id_map[shape['label']]
else:
print('label %s not in %s' % shape['label'], self._label_list)
return label_id, yolo_center_x, yolo_center_y, yolo_w, yolo_h
@ -290,7 +299,11 @@ class Labelme2YOLO(object):
points = extend_point_list(points)
if self._output_format == "bbox":
points = extend_point_list(points, "bbox")
label_id = self._label_id_map[shape['label']]
if shape['label'] in self._label_id_map:
label_id = self._label_id_map[shape['label']]
else:
print('label %s not in %s' % shape['label'], self._label_list)
return label_id, points.tolist()