cleanup code (#28)

This commit is contained in:
Wang Xin 2023-06-07 20:14:46 +08:00 committed by GitHub
parent 80fbc35f51
commit aefad9e1ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 30 additions and 24 deletions

View File

@ -66,10 +66,7 @@ def img_arr_to_b64(img_arr):
file = io.BytesIO()
img_pil.save(file, format="PNG")
img_bin = file.getvalue()
if hasattr(base64, "encodebytes"):
img_b64 = base64.encodebytes(img_bin)
else:
img_b64 = base64.encodestring(img_bin)
img_b64 = base64.encodebytes(img_bin)
return img_b64
@ -93,7 +90,8 @@ def get_label_id_map(json_dir: str):
for file_name in os.listdir(json_dir):
if file_name.endswith("json"):
json_path = os.path.join(json_dir, file_name)
data = json.load(open(json_path))
with open(json_path, encoding="utf-8") as file:
data = json.load(file)
for shape in data["shapes"]:
label_set.add(shape["label"])
@ -102,13 +100,11 @@ def get_label_id_map(json_dir: str):
def extend_point_list(point_list, out_format="polygon"):
'''Extend point list to polygon or bbox'''
xmin = min([float(point) for point in point_list[::2]])
xmax = max([float(point) for point in point_list[::2]])
ymin = min([float(point) for point in point_list[1::2]])
ymax = max([float(point) for point in point_list[1::2]])
xmin = min(float(point) for point in point_list[::2])
xmax = max(float(point) for point in point_list[::2])
ymin = min(float(point) for point in point_list[1::2])
ymax = max(float(point) for point in point_list[1::2])
if out_format == "polygon":
return np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
if out_format == "bbox":
x = xmin
y = ymin
@ -118,6 +114,8 @@ def extend_point_list(point_list, out_format="polygon"):
y = y + h / 2
return np.array([x, y, w, h])
return np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
def save_yolo_label(json_name, label_dir_path, target_dir, yolo_obj_list):
'''Save yolo label to txt file'''
@ -125,7 +123,7 @@ def save_yolo_label(json_name, label_dir_path, target_dir, yolo_obj_list):
target_dir,
json_name.replace(".json", ".txt"))
with open(txt_path, "w+") as f:
with open(txt_path, "w+", encoding="utf-8") as f:
for yolo_obj in yolo_obj_list:
label, points = yolo_obj
points = [str(item) for item in points]
@ -154,19 +152,22 @@ def save_yolo_image(json_data, json_path, image_dir_path, target_dir):
return img_path
class Labelme2YOLO(object):
class Labelme2YOLO:
'''Labelme to YOLO format converter'''
def __init__(self, json_dir, output_format, label_list):
self._json_dir = json_dir
self._output_format = output_format
self._label_list = label_list
self._label_dir_path = ""
self._image_dir_path = ""
if label_list:
self._label_id_map = {label: label_id
for label_id, label in enumerate(label_list)}
else:
self._label_id_map = get_label_id_map(self._json_dir)
self._label_list = [label for label in self._label_id_map.keys()]
self._label_list = list(self._label_id_map.keys())
def _make_train_val_dir(self):
self._label_dir_path = os.path.join(self._json_dir,
@ -186,7 +187,10 @@ class Labelme2YOLO(object):
os.makedirs(yolo_path)
def _train_test_split(self, folders, json_names, val_size, test_size):
if len(folders) > 0 and 'train' in folders and 'val' in folders and 'test' in folders: # noqa: E501
if (len(folders) > 0 and
'train' in folders and
'val' in folders and
'test' in folders):
train_folder = os.path.join(self._json_dir, 'train/')
train_json_names = [train_sample_name + '.json'
for train_sample_name in os.listdir(train_folder)
@ -219,6 +223,7 @@ class Labelme2YOLO(object):
return train_json_names, val_json_names, test_json_names
def convert(self, val_size, test_size):
'''Convert labelme format to yolo format'''
json_names = [file_name for file_name in os.listdir(self._json_dir)
if os.path.isfile(os.path.join(self._json_dir, file_name)) and
file_name.endswith('.json')]
@ -233,23 +238,24 @@ class Labelme2YOLO(object):
# also get image from labelme json file and save them under images folder
for target_dir, json_names in zip(('train/', 'val/', 'test/'),
(train_json_names, val_json_names, test_json_names)): # noqa: E501
pool = Pool(NUM_THREADS)
for json_name in json_names:
pool.apply_async(self.covert_json_to_text,
args=(target_dir, json_name))
pool.close()
pool.join()
with Pool(NUM_THREADS) as pool:
for json_name in json_names:
pool.apply_async(self.covert_json_to_text,
args=(target_dir, json_name))
pool.close()
pool.join()
print('Generating dataset.yaml file ...')
self._save_dataset_yaml()
def covert_json_to_text(self, target_dir, json_name):
"""Convert json file to yolo format text file and save them to files"""
json_path = os.path.join(self._json_dir, json_name)
json_data = json.load(open(json_path))
with open(json_path, encoding="utf-8") as f:
json_data = json.load(f)
print('Converting %s for %s ...' %
(json_name, target_dir.replace('/', '')))
print(f"Converting {json_name} for {target_dir.replace('/', '')} ...")
img_path = save_yolo_image(json_data,
json_path,