remove scikit-learn dependency

Wang Xin 2023-10-04 16:28:20 +08:00
parent 10dacce294
commit 9f2f443dab
5 changed files with 60 additions and 52 deletions

View File

@ -6,7 +6,7 @@ build-backend = "hatchling.build"
name = "labelme2yolo"
description = "This script converts the JSON format output by LabelMe to the text format required by YOLO serirs."
readme = "README.md"
requires-python = ">=3.7"
requires-python = ">=3.8"
license = "MIT"
keywords = []
authors = [
@ -15,7 +15,6 @@ authors = [
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
@ -25,7 +24,6 @@ classifiers = [
dependencies = [
"opencv-python>=4.1.2",
"Pillow>=9.2,<10.1",
"scikit-learn>=1.1.1,<1.4.0",
"numpy>=1.23.1,<1.27.0"
]
dynamic = ["version"]
@ -54,7 +52,7 @@ cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=labelm
no-cov = "cov --no-cov"
[[tool.hatch.envs.test.matrix]]
python = ["37", "38", "39", "310"]
python = ["38", "39", "310"]
[tool.coverage.run]
branch = true
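
A quick sanity check for the dependency change (a sketch, not part of this commit; it assumes the package is importable as labelme2yolo and is installed from this revision or later): importing it should no longer pull scikit-learn into sys.modules, and scikit-learn does not need to be installed at all.

import importlib.util
import sys

import labelme2yolo  # assumed import name; installed from this revision or later

# The import chain should no longer load scikit-learn transitively.
assert "sklearn" not in sys.modules
# scikit-learn may still be installed for other projects; it is simply no longer required.
print("sklearn installed:", importlib.util.find_spec("sklearn") is not None)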

View File

@ -1,4 +1,3 @@
opencv-python
Pillow
scikit-learn
numpy

View File

@ -4,4 +4,4 @@
'''
about version
'''
__version__ = '0.1.2'
__version__ = '0.1.3'

View File

@ -21,14 +21,14 @@ def run():
"--val_size",
type=float,
nargs="?",
default=None,
default=0.2,
help="Please input the validation dataset size, for example 0.1.",
)
parser.add_argument(
"--test_size",
type=float,
nargs="?",
default=None,
default=0.0,
help="Please input the test dataset size, for example 0.1.",
)
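# Sketch of the new defaults (annotation, not part of this commit): with both
# flags omitted, the parser now falls back to an 80/20/0 train/val/test split
# instead of leaving the sizes unset:
#
#     >>> parser.get_default("val_size"), parser.get_default("test_size")
#     (0.2, 0.0)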
parser.add_argument(

View File

@ -9,25 +9,38 @@ import io
import json
import math
import os
import random
import shutil
from collections import OrderedDict
from multiprocessing import Pool
import cv2
import numpy as np
import PIL.ExifTags
import PIL.Image
import PIL.ImageOps
from sklearn.model_selection import train_test_split
import cv2
import numpy as np
random.seed(12345678)
np.random.seed(12345678)
# number of LabelMe2YOLO multiprocessing threads
NUM_THREADS = max(1, os.cpu_count() - 1)
def train_test_split(dataset_index, test_size=0.2):
test_size = min(max(0.0, test_size), 1.0)
total_size = len(dataset_index)
train_size = int(math.ceil(total_size * (1.0 - test_size)))
random.shuffle(dataset_index)
train_index = dataset_index[:train_size]
test_index = dataset_index[train_size:]
return train_index, test_index
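# Usage sketch (annotation, not part of this commit): the helper mirrors the
# sklearn.model_selection.train_test_split call it replaces for this use case:
# it shuffles a plain index list with the seeded module-level random and cuts
# it at train_size = ceil(total * (1 - test_size)):
#
#     >>> train_idx, test_idx = train_test_split(list(range(10)), test_size=0.2)
#     >>> len(train_idx), len(test_idx)
#     (8, 2)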
# copied from https://github.com/wkentaro/labelme/blob/main/labelme/utils/image.py
def img_data_to_pil(img_data):
'''Convert img_data(byte) to PIL.Image'''
"""Convert img_data(byte) to PIL.Image"""
file = io.BytesIO()
file.write(img_data)
img_pil = PIL.Image.open(file)
@ -36,7 +49,7 @@ def img_data_to_pil(img_data):
# copied from https://github.com/wkentaro/labelme/blob/main/labelme/utils/image.py
def img_data_to_arr(img_data):
'''Convert img_data(byte) to numpy.ndarray'''
"""Convert img_data(byte) to numpy.ndarray"""
img_pil = img_data_to_pil(img_data)
img_arr = np.array(img_pil)
return img_arr
@ -44,7 +57,7 @@ def img_data_to_arr(img_data):
# copied from https://github.com/wkentaro/labelme/blob/main/labelme/utils/image.py
def img_b64_to_arr(img_b64):
'''Convert img_b64(str) to numpy.ndarray'''
"""Convert img_b64(str) to numpy.ndarray"""
img_data = base64.b64decode(img_b64)
img_arr = img_data_to_arr(img_data)
return img_arr
@ -52,7 +65,7 @@ def img_b64_to_arr(img_b64):
# copied from https://github.com/wkentaro/labelme/blob/main/labelme/utils/image.py
def img_pil_to_data(img_pil):
'''Convert PIL.Image to img_data(byte)'''
"""Convert PIL.Image to img_data(byte)"""
file = io.BytesIO()
img_pil.save(file, format="PNG")
img_data = file.getvalue()
@ -61,7 +74,7 @@ def img_pil_to_data(img_pil):
# copied from https://github.com/wkentaro/labelme/blob/main/labelme/utils/image.py
def img_arr_to_b64(img_arr):
'''Convert numpy.ndarray to img_b64(str)'''
"""Convert numpy.ndarray to img_b64(str)"""
img_pil = PIL.Image.fromarray(img_arr)
file = io.BytesIO()
img_pil.save(file, format="PNG")
@ -72,7 +85,7 @@ def img_arr_to_b64(img_arr):
# copied from https://github.com/wkentaro/labelme/blob/main/labelme/utils/image.py
def img_data_to_png_data(img_data):
'''Convert img_data(byte) to png_data(byte)'''
"""Convert img_data(byte) to png_data(byte)"""
with io.BytesIO() as f_out:
f_out.write(img_data)
img = PIL.Image.open(f_out)
@ -84,7 +97,7 @@ def img_data_to_png_data(img_data):
def get_label_id_map(json_dir: str):
'''Get label id map from json files in json_dir'''
"""Get label id map from json files in json_dir"""
label_set = set()
for file_name in os.listdir(json_dir):
@ -99,26 +112,26 @@ def get_label_id_map(json_dir: str):
def extend_point_list(point_list, out_format="polygon"):
'''Extend point list to polygon or bbox'''
xmin = min(float(point) for point in point_list[::2])
xmax = max(float(point) for point in point_list[::2])
ymin = min(float(point) for point in point_list[1::2])
ymax = max(float(point) for point in point_list[1::2])
"""Extend point list to polygon or bbox"""
x_min = min(float(point) for point in point_list[::2])
x_max = max(float(point) for point in point_list[::2])
y_min = min(float(point) for point in point_list[1::2])
y_max = max(float(point) for point in point_list[1::2])
if out_format == "bbox":
x_i = xmin
y_i = ymin
w_i = xmax - xmin
h_i = ymax - ymin
x_i = x_min
y_i = y_min
w_i = x_max - x_min
h_i = y_max - y_min
x_i = x_i + w_i / 2
y_i = y_i + h_i / 2
return np.array([x_i, y_i, w_i, h_i])
return np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
return np.array([x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max])
def save_yolo_label(json_name, label_dir_path, target_dir, yolo_obj_list):
'''Save yolo label to txt file'''
"""Save yolo label to txt file"""
txt_path = os.path.join(label_dir_path,
target_dir,
json_name.replace(".json", ".txt"))
@ -132,7 +145,7 @@ def save_yolo_label(json_name, label_dir_path, target_dir, yolo_obj_list):
def save_yolo_image(json_data, json_path, image_dir_path, target_dir):
'''Save yolo image to image_dir_path/target_dir'''
"""Save yolo image to image_dir_path/target_dir"""
json_name = os.path.basename(json_path)
img_name = json_name.replace(".json", ".png")
@ -153,7 +166,7 @@ def save_yolo_image(json_data, json_path, image_dir_path, target_dir):
class Labelme2YOLO:
'''Labelme to YOLO format converter'''
"""Labelme to YOLO format converter"""
def __init__(self, json_dir, output_format, label_list):
self._json_dir = json_dir
@ -187,7 +200,7 @@ class Labelme2YOLO:
os.makedirs(yolo_path)
def _get_dataset_part_json_names(self, dataset_part: str):
'''Get json names in dataset_part folder'''
"""Get json names in dataset_part folder"""
set_folder = os.path.join(self._json_dir, dataset_part)
json_names = []
for sample_name in os.listdir(set_folder):
@ -197,34 +210,35 @@ class Labelme2YOLO:
return json_names
def _train_test_split(self, folders, json_names, val_size, test_size):
'''Split json names to train, val, test'''
"""Split json names to train, val, test"""
if (len(folders) > 0 and
'train' in folders and
'val' in folders and
'test' in folders):
train_json_names = self._get_dataset_part_json_names('train')
val_json_names = self._get_dataset_part_json_names('val')
test_json_names = self._get_dataset_part_json_names('test')
return train_json_names, val_json_names, test_json_names
train_idxs, val_idxs = train_test_split(range(len(json_names)),
total_size = len(json_names)
dataset_index = list(range(total_size))
train_ids, val_ids = train_test_split(dataset_index,
test_size=val_size)
test_idxs = []
test_ids = []
if test_size is None:
test_size = 0.0
if test_size > 1e-8:
train_idxs, test_idxs = train_test_split(
train_idxs, test_size=test_size / (1 - val_size))
train_json_names = [json_names[train_idx] for train_idx in train_idxs]
val_json_names = [json_names[val_idx] for val_idx in val_idxs]
test_json_names = [json_names[test_idx] for test_idx in test_idxs]
train_ids, test_ids = train_test_split(
train_ids, test_size=test_size / (1 - val_size))
train_json_names = [json_names[train_idx] for train_idx in train_ids]
val_json_names = [json_names[val_idx] for val_idx in val_ids]
test_json_names = [json_names[test_idx] for test_idx in test_ids]
return train_json_names, val_json_names, test_json_names
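# Worked example (annotation, not part of this commit): with 100 JSON files,
# val_size=0.2 and test_size=0.1, the first split leaves 80 train / 20 val;
# the second split then takes test_size / (1 - val_size) = 0.1 / 0.8 = 0.125
# of those 80, i.e. 10 files, giving a final 70 / 20 / 10 partition, the same
# proportions the old sklearn-based calls produced.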
def convert(self, val_size, test_size):
'''Convert labelme format to yolo format'''
"""Convert labelme format to yolo format"""
json_names = [file_name for file_name in os.listdir(self._json_dir)
if os.path.isfile(os.path.join(self._json_dir, file_name)) and
file_name.endswith('.json')]
@ -317,10 +331,9 @@ class Labelme2YOLO:
if shape['label'] in self._label_id_map:
label_id = self._label_id_map[shape['label']]
else:
print(f"label {shape['label']} not in {self._label_list}")
return label_id, yolo_center_x, yolo_center_y, yolo_w, yolo_h
else:
raise f"label {shape['label']} not in {self._label_list}"
def _get_other_shape_yolo_object(self, shape, img_h, img_w):
@ -336,17 +349,15 @@ class Labelme2YOLO:
if shape['label'] in self._label_id_map:
label_id = self._label_id_map[shape['label']]
else:
print(f"label {shape['label']} not in {self._label_list}")
return label_id, points.tolist()
else:
raise f"label {shape['label']} not in {self._label_list}"
def _save_dataset_yaml(self):
yaml_path = os.path.join(
self._json_dir, 'YOLODataset/', 'dataset.yaml')
with open(yaml_path, 'w+', encoding="utf-8") as yaml_file:
train_dir = os.path.join(self._image_dir_path, 'train/')
val_dir = os.path.join(self._image_dir_path, 'val/')
test_dir = os.path.join(self._image_dir_path, 'test/')