remove scikit-learn dependence
This commit is contained in:
parent
10dacce294
commit
9f2f443dab
|
@ -6,7 +6,7 @@ build-backend = "hatchling.build"
|
|||
name = "labelme2yolo"
|
||||
description = "This script converts the JSON format output by LabelMe to the text format required by the YOLO series."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.7"
|
||||
requires-python = ">=3.8"
|
||||
license = "MIT"
|
||||
keywords = []
|
||||
authors = [
|
||||
|
@ -15,7 +15,6 @@ authors = [
|
|||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
|
@ -25,7 +24,6 @@ classifiers = [
|
|||
dependencies = [
|
||||
"opencv-python>=4.1.2",
|
||||
"Pillow>=9.2,<10.1",
|
||||
"scikit-learn>=1.1.1,<1.4.0",
|
||||
"numpy>=1.23.1,<1.27.0"
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
@ -54,7 +52,7 @@ cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=labelm
|
|||
no-cov = "cov --no-cov"
|
||||
|
||||
[[tool.hatch.envs.test.matrix]]
|
||||
python = ["37", "38", "39", "310"]
|
||||
python = ["38", "39", "310"]
|
||||
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
opencv-python
|
||||
Pillow
|
||||
scikit-learn
|
||||
numpy
|
||||
|
|
|
@ -4,4 +4,4 @@
|
|||
'''
|
||||
about version
|
||||
'''
|
||||
__version__ = '0.1.2'
|
||||
__version__ = '0.1.3'
|
||||
|
|
|
@ -21,14 +21,14 @@ def run():
|
|||
"--val_size",
|
||||
type=float,
|
||||
nargs="?",
|
||||
default=None,
|
||||
default=0.2,
|
||||
help="Please input the validation dataset size, for example 0.1.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_size",
|
||||
type=float,
|
||||
nargs="?",
|
||||
default=None,
|
||||
default=0.0,
|
||||
help="Please input the test dataset size, for example 0.1.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
|
|
@ -9,25 +9,38 @@ import io
|
|||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from multiprocessing import Pool
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL.ExifTags
|
||||
import PIL.Image
|
||||
import PIL.ImageOps
|
||||
from sklearn.model_selection import train_test_split
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
random.seed(12345678)
|
||||
np.random.seed(12345678)
|
||||
|
||||
# Number of LabelMe2YOLO multiprocessing workers: all CPUs but one, at least 1.
# os.cpu_count() returns None when the CPU count cannot be determined, so
# fall back to 1 instead of failing with "None - 1" at import time.
NUM_THREADS = max(1, (os.cpu_count() or 1) - 1)
|
||||
|
||||
|
||||
def train_test_split(dataset_index, test_size=0.2):
    """Randomly split ``dataset_index`` into a train part and a test part.

    The items are copied into a new list before shuffling, so the caller's
    sequence is never reordered and any iterable (list, tuple, range, ...)
    is accepted.

    Args:
        dataset_index: iterable of items (typically integer indices) to split.
        test_size: fraction of the items assigned to the test part; values
            outside [0.0, 1.0] are clamped to that range.

    Returns:
        A ``(train_index, test_index)`` tuple of lists. The train part holds
        ``ceil(total * (1 - test_size))`` items and the test part the rest.
    """
    test_size = min(max(0.0, test_size), 1.0)

    # Shuffle a private copy: random.shuffle works in place and would
    # otherwise reorder the caller's list (and fail on immutable sequences).
    shuffled = list(dataset_index)
    random.shuffle(shuffled)

    train_size = int(math.ceil(len(shuffled) * (1.0 - test_size)))
    return shuffled[:train_size], shuffled[train_size:]
|
||||
|
||||
|
||||
# copied from https://github.com/wkentaro/labelme/blob/main/labelme/utils/image.py
|
||||
def img_data_to_pil(img_data):
|
||||
'''Convert img_data(byte) to PIL.Image'''
|
||||
"""Convert img_data(byte) to PIL.Image"""
|
||||
file = io.BytesIO()
|
||||
file.write(img_data)
|
||||
img_pil = PIL.Image.open(file)
|
||||
|
@ -36,7 +49,7 @@ def img_data_to_pil(img_data):
|
|||
|
||||
# copied from https://github.com/wkentaro/labelme/blob/main/labelme/utils/image.py
|
||||
def img_data_to_arr(img_data):
    """Convert img_data (bytes) to a numpy.ndarray."""
    # Decode the bytes with img_data_to_pil, then hand the PIL image to numpy.
    return np.array(img_data_to_pil(img_data))
|
||||
|
@ -44,7 +57,7 @@ def img_data_to_arr(img_data):
|
|||
|
||||
# copied from https://github.com/wkentaro/labelme/blob/main/labelme/utils/image.py
|
||||
def img_b64_to_arr(img_b64):
    """Convert img_b64 (base64 str) to a numpy.ndarray."""
    # Base64-decode to raw image bytes, then reuse the bytes -> array helper.
    return img_data_to_arr(base64.b64decode(img_b64))
|
||||
|
@ -52,7 +65,7 @@ def img_b64_to_arr(img_b64):
|
|||
|
||||
# copied from https://github.com/wkentaro/labelme/blob/main/labelme/utils/image.py
|
||||
def img_pil_to_data(img_pil):
|
||||
'''Convert PIL.Image to img_data(byte)'''
|
||||
"""Convert PIL.Image to img_data(byte)"""
|
||||
file = io.BytesIO()
|
||||
img_pil.save(file, format="PNG")
|
||||
img_data = file.getvalue()
|
||||
|
@ -61,7 +74,7 @@ def img_pil_to_data(img_pil):
|
|||
|
||||
# copied from https://github.com/wkentaro/labelme/blob/main/labelme/utils/image.py
|
||||
def img_arr_to_b64(img_arr):
|
||||
'''Convert numpy.ndarray to img_b64(str)'''
|
||||
"""Convert numpy.ndarray to img_b64(str)"""
|
||||
img_pil = PIL.Image.fromarray(img_arr)
|
||||
file = io.BytesIO()
|
||||
img_pil.save(file, format="PNG")
|
||||
|
@ -72,7 +85,7 @@ def img_arr_to_b64(img_arr):
|
|||
|
||||
# copied from https://github.com/wkentaro/labelme/blob/main/labelme/utils/image.py
|
||||
def img_data_to_png_data(img_data):
|
||||
'''Convert img_data(byte) to png_data(byte)'''
|
||||
"""Convert img_data(byte) to png_data(byte)"""
|
||||
with io.BytesIO() as f_out:
|
||||
f_out.write(img_data)
|
||||
img = PIL.Image.open(f_out)
|
||||
|
@ -84,7 +97,7 @@ def img_data_to_png_data(img_data):
|
|||
|
||||
|
||||
def get_label_id_map(json_dir: str):
|
||||
'''Get label id map from json files in json_dir'''
|
||||
"""Get label id map from json files in json_dir"""
|
||||
label_set = set()
|
||||
|
||||
for file_name in os.listdir(json_dir):
|
||||
|
@ -99,26 +112,26 @@ def get_label_id_map(json_dir: str):
|
|||
|
||||
|
||||
def extend_point_list(point_list, out_format="polygon"):
    """Extend a flat point list to its enclosing bbox or rectangle polygon.

    ``point_list`` is read as alternating x, y values (even positions are x,
    odd positions are y); every value is converted with ``float``.

    Args:
        point_list: flat sequence ``[x0, y0, x1, y1, ...]``.
        out_format: ``"bbox"`` returns ``[center_x, center_y, width, height]``;
            any other value returns the four rectangle corners as
            ``[x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max]``.

    Returns:
        numpy.ndarray with 4 values (bbox) or 8 values (polygon).
    """
    xs = [float(value) for value in point_list[::2]]
    ys = [float(value) for value in point_list[1::2]]
    left, right = min(xs), max(xs)
    top, bottom = min(ys), max(ys)

    if out_format != "bbox":
        # Corners in the order: top-left, top-right, bottom-right, bottom-left.
        return np.array([left, top, right, top, right, bottom, left, bottom])

    width = right - left
    height = bottom - top
    return np.array([left + width / 2, top + height / 2, width, height])
|
||||
|
||||
|
||||
def save_yolo_label(json_name, label_dir_path, target_dir, yolo_obj_list):
|
||||
'''Save yolo label to txt file'''
|
||||
"""Save yolo label to txt file"""
|
||||
txt_path = os.path.join(label_dir_path,
|
||||
target_dir,
|
||||
json_name.replace(".json", ".txt"))
|
||||
|
@ -132,7 +145,7 @@ def save_yolo_label(json_name, label_dir_path, target_dir, yolo_obj_list):
|
|||
|
||||
|
||||
def save_yolo_image(json_data, json_path, image_dir_path, target_dir):
|
||||
'''Save yolo image to image_dir_path/target_dir'''
|
||||
"""Save yolo image to image_dir_path/target_dir"""
|
||||
json_name = os.path.basename(json_path)
|
||||
img_name = json_name.replace(".json", ".png")
|
||||
|
||||
|
@ -153,7 +166,7 @@ def save_yolo_image(json_data, json_path, image_dir_path, target_dir):
|
|||
|
||||
|
||||
class Labelme2YOLO:
|
||||
'''Labelme to YOLO format converter'''
|
||||
"""Labelme to YOLO format converter"""
|
||||
|
||||
def __init__(self, json_dir, output_format, label_list):
|
||||
self._json_dir = json_dir
|
||||
|
@ -187,7 +200,7 @@ class Labelme2YOLO:
|
|||
os.makedirs(yolo_path)
|
||||
|
||||
def _get_dataset_part_json_names(self, dataset_part: str):
|
||||
'''Get json names in dataset_part folder'''
|
||||
"""Get json names in dataset_part folder"""
|
||||
set_folder = os.path.join(self._json_dir, dataset_part)
|
||||
json_names = []
|
||||
for sample_name in os.listdir(set_folder):
|
||||
|
@ -197,34 +210,35 @@ class Labelme2YOLO:
|
|||
return json_names
|
||||
|
||||
def _train_test_split(self, folders, json_names, val_size, test_size):
    """Split json names into train, val and test lists.

    When ``folders`` contains all of 'train', 'val' and 'test', the names
    are taken from those sub-folders via ``_get_dataset_part_json_names``
    and no random split is done. Otherwise ``json_names`` is split randomly
    with the module-level ``train_test_split``.

    Args:
        folders: collection of sub-folder names found in the json directory.
        json_names: list of json file names to split.
        val_size: fraction of all samples used for validation; ``None`` is
            treated as 0.0.
        test_size: fraction of all samples used for testing; ``None`` is
            treated as 0.0.

    Returns:
        A ``(train_json_names, val_json_names, test_json_names)`` tuple of
        lists.
    """
    if all(part in folders for part in ('train', 'val', 'test')):
        train_json_names = self._get_dataset_part_json_names('train')
        val_json_names = self._get_dataset_part_json_names('val')
        test_json_names = self._get_dataset_part_json_names('test')

        return train_json_names, val_json_names, test_json_names

    # Treat a missing size as "no samples for this part", for both sizes.
    if val_size is None:
        val_size = 0.0
    if test_size is None:
        test_size = 0.0

    dataset_index = list(range(len(json_names)))
    train_ids, val_ids = train_test_split(dataset_index,
                                          test_size=val_size)
    test_ids = []

    # test_size is a fraction of the whole dataset, so rescale it to the
    # share left after the validation split. Skip the second split when
    # nothing is left (val_size >= 1) to avoid dividing by zero.
    remaining = 1.0 - val_size
    if test_size > 1e-8 and remaining > 1e-8:
        train_ids, test_ids = train_test_split(
            train_ids, test_size=test_size / remaining)

    train_json_names = [json_names[train_idx] for train_idx in train_ids]
    val_json_names = [json_names[val_idx] for val_idx in val_ids]
    test_json_names = [json_names[test_idx] for test_idx in test_ids]

    return train_json_names, val_json_names, test_json_names
|
||||
|
||||
def convert(self, val_size, test_size):
|
||||
'''Convert labelme format to yolo format'''
|
||||
"""Convert labelme format to yolo format"""
|
||||
json_names = [file_name for file_name in os.listdir(self._json_dir)
|
||||
if os.path.isfile(os.path.join(self._json_dir, file_name)) and
|
||||
file_name.endswith('.json')]
|
||||
|
@ -317,10 +331,9 @@ class Labelme2YOLO:
|
|||
|
||||
if shape['label'] in self._label_id_map:
|
||||
label_id = self._label_id_map[shape['label']]
|
||||
return label_id, yolo_center_x, yolo_center_y, yolo_w, yolo_h
|
||||
else:
|
||||
print(f"label {shape['label']} not in {self._label_list}")
|
||||
|
||||
return label_id, yolo_center_x, yolo_center_y, yolo_w, yolo_h
|
||||
raise f"label {shape['label']} not in {self._label_list}"
|
||||
|
||||
def _get_other_shape_yolo_object(self, shape, img_h, img_w):
|
||||
|
||||
|
@ -336,17 +349,15 @@ class Labelme2YOLO:
|
|||
|
||||
if shape['label'] in self._label_id_map:
|
||||
label_id = self._label_id_map[shape['label']]
|
||||
return label_id, points.tolist()
|
||||
else:
|
||||
print(f"label {shape['label']} not in {self._label_list}")
|
||||
|
||||
return label_id, points.tolist()
|
||||
raise f"label {shape['label']} not in {self._label_list}"
|
||||
|
||||
def _save_dataset_yaml(self):
|
||||
yaml_path = os.path.join(
|
||||
self._json_dir, 'YOLODataset/', 'dataset.yaml')
|
||||
|
||||
with open(yaml_path, 'w+', encoding="utf-8") as yaml_file:
|
||||
|
||||
train_dir = os.path.join(self._image_dir_path, 'train/')
|
||||
val_dir = os.path.join(self._image_dir_path, 'val/')
|
||||
test_dir = os.path.join(self._image_dir_path, 'test/')
|
||||
|
|
Loading…
Reference in New Issue