diff --git a/src/labelme2yolo/l2y.py b/src/labelme2yolo/l2y.py index 0be1d1f..850e177 100644 --- a/src/labelme2yolo/l2y.py +++ b/src/labelme2yolo/l2y.py @@ -28,6 +28,7 @@ NUM_THREADS = max(1, os.cpu_count() - 1) def train_test_split(dataset_index, test_size=0.2): + """Split dataset into train set and test set with test_size""" test_size = min(max(0.0, test_size), 1.0) total_size = len(dataset_index) train_size = int(math.ceil(total_size * (1.0 - test_size))) @@ -332,8 +333,8 @@ class Labelme2YOLO: if shape['label'] in self._label_id_map: label_id = self._label_id_map[shape['label']] return label_id, yolo_center_x, yolo_center_y, yolo_w, yolo_h - else: - raise f"label {shape['label']} not in {self._label_list}" + + raise f"label {shape['label']} not in {self._label_list}" def _get_other_shape_yolo_object(self, shape, img_h, img_w): @@ -350,8 +351,8 @@ class Labelme2YOLO: if shape['label'] in self._label_id_map: label_id = self._label_id_map[shape['label']] return label_id, points.tolist() - else: - raise f"label {shape['label']} not in {self._label_list}" + + raise f"label {shape['label']} not in {self._label_list}" def _save_dataset_yaml(self): yaml_path = os.path.join(