diff --git a/src/labelme2yolo/__about__.py b/src/labelme2yolo/__about__.py
index 485f66a..0822a15 100644
--- a/src/labelme2yolo/__about__.py
+++ b/src/labelme2yolo/__about__.py
@@ -4,4 +4,4 @@
 """
 about version
 """
-__version__ = "0.1.4"
+__version__ = "0.1.5"
diff --git a/src/labelme2yolo/l2y.py b/src/labelme2yolo/l2y.py
index 81dd6f3..92d1a1b 100644
--- a/src/labelme2yolo/l2y.py
+++ b/src/labelme2yolo/l2y.py
@@ -15,6 +15,7 @@
 import shutil
 import uuid
 import logging
+from multiprocessing import Pool
 import PIL.ExifTags
 import PIL.Image
 import PIL.ImageOps
@@ -163,6 +164,20 @@ class Labelme2YOLO:
             self._label_id_map = {
                 label: label_id for label_id, label in enumerate(label_list)
             }
+        else:
+            # get label list from json files for parallel processing
+            json_files = glob.glob(
+                os.path.join(self._json_dir, "**", "*.json"), recursive=True
+            )
+            for json_file in json_files:
+                with open(json_file, encoding="utf-8") as file:
+                    json_data = json.load(file)
+                    for shape in json_data["shapes"]:
+                        if shape["label"] not in self._label_list:
+                            self._label_list.append(shape["label"])
+            self._label_id_map = {
+                label: label_id for label_id, label in enumerate(self._label_list)
+            }
 
     def _update_id_map(self, label: str):
         if label not in self._label_list:
@@ -230,8 +245,21 @@ class Labelme2YOLO:
         names = (train_json_names, val_json_names, test_json_names)
         for target_dir, json_names in zip(dirs, names):
             logger.info("Converting %s set ...", target_dir)
-            for json_name in tqdm.tqdm(json_names):
-                self.covert_json_to_text(target_dir, json_name)
+            # Keep one core free, but never request fewer than one worker:
+            # os.cpu_count() can return 1 (Pool(0) raises ValueError) or None.
+            workers = max(1, (os.cpu_count() or 1) - 1)
+            with Pool(workers) as pool:
+                # Queue every file first so the workers run concurrently;
+                # pool.apply() blocks per call and would serialize the loop.
+                tasks = [
+                    pool.apply_async(
+                        self.covert_json_to_text, (target_dir, json_name)
+                    )
+                    for json_name in json_names
+                ]
+                for task in tqdm.tqdm(tasks):
+                    # get() waits for the job and re-raises worker errors.
+                    task.get()
 
         self._save_dataset_yaml()
 