diff --git a/labelme2yolo.py b/labelme2yolo.py index 28de13c..b338241 100644 --- a/labelme2yolo.py +++ b/labelme2yolo.py @@ -78,7 +78,9 @@ class Labelme2YOLO(object): train_idxs, val_idxs = train_test_split(range(len(json_names)), test_size=val_size) tmp_train_len = len(train_idxs) - train_idxs, test_idxs = train_test_split(range(tmp_train_len), test_size=test_size / (1 - val_size)) + test_idxs = [] + if test_size > 1e-8: + train_idxs, test_idxs = train_test_split(range(tmp_train_len), test_size=test_size / (1 - val_size)) train_json_names = [json_names[train_idx] for train_idx in train_idxs] val_json_names = [json_names[val_idx] for val_idx in val_idxs] test_json_names = [json_names[test_idx] for test_idx in test_idxs]