cleanup code (#28)
This commit is contained in:
parent
80fbc35f51
commit
aefad9e1ab
|
@ -66,10 +66,7 @@ def img_arr_to_b64(img_arr):
|
|||
file = io.BytesIO()
|
||||
img_pil.save(file, format="PNG")
|
||||
img_bin = file.getvalue()
|
||||
if hasattr(base64, "encodebytes"):
|
||||
img_b64 = base64.encodebytes(img_bin)
|
||||
else:
|
||||
img_b64 = base64.encodestring(img_bin)
|
||||
img_b64 = base64.encodebytes(img_bin)
|
||||
return img_b64
|
||||
|
||||
|
||||
|
@ -93,7 +90,8 @@ def get_label_id_map(json_dir: str):
|
|||
for file_name in os.listdir(json_dir):
|
||||
if file_name.endswith("json"):
|
||||
json_path = os.path.join(json_dir, file_name)
|
||||
data = json.load(open(json_path))
|
||||
with open(json_path, encoding="utf-8") as file:
|
||||
data = json.load(file)
|
||||
for shape in data["shapes"]:
|
||||
label_set.add(shape["label"])
|
||||
|
||||
|
@ -102,13 +100,11 @@ def get_label_id_map(json_dir: str):
|
|||
|
||||
def extend_point_list(point_list, out_format="polygon"):
|
||||
'''Extend point list to polygon or bbox'''
|
||||
xmin = min([float(point) for point in point_list[::2]])
|
||||
xmax = max([float(point) for point in point_list[::2]])
|
||||
ymin = min([float(point) for point in point_list[1::2]])
|
||||
ymax = max([float(point) for point in point_list[1::2]])
|
||||
xmin = min(float(point) for point in point_list[::2])
|
||||
xmax = max(float(point) for point in point_list[::2])
|
||||
ymin = min(float(point) for point in point_list[1::2])
|
||||
ymax = max(float(point) for point in point_list[1::2])
|
||||
|
||||
if out_format == "polygon":
|
||||
return np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
|
||||
if out_format == "bbox":
|
||||
x = xmin
|
||||
y = ymin
|
||||
|
@ -118,6 +114,8 @@ def extend_point_list(point_list, out_format="polygon"):
|
|||
y = y + h / 2
|
||||
return np.array([x, y, w, h])
|
||||
|
||||
return np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
|
||||
|
||||
|
||||
def save_yolo_label(json_name, label_dir_path, target_dir, yolo_obj_list):
|
||||
'''Save yolo label to txt file'''
|
||||
|
@ -125,7 +123,7 @@ def save_yolo_label(json_name, label_dir_path, target_dir, yolo_obj_list):
|
|||
target_dir,
|
||||
json_name.replace(".json", ".txt"))
|
||||
|
||||
with open(txt_path, "w+") as f:
|
||||
with open(txt_path, "w+", encoding="utf-8") as f:
|
||||
for yolo_obj in yolo_obj_list:
|
||||
label, points = yolo_obj
|
||||
points = [str(item) for item in points]
|
||||
|
@ -154,19 +152,22 @@ def save_yolo_image(json_data, json_path, image_dir_path, target_dir):
|
|||
return img_path
|
||||
|
||||
|
||||
class Labelme2YOLO(object):
|
||||
class Labelme2YOLO:
|
||||
'''Labelme to YOLO format converter'''
|
||||
|
||||
def __init__(self, json_dir, output_format, label_list):
|
||||
self._json_dir = json_dir
|
||||
self._output_format = output_format
|
||||
self._label_list = label_list
|
||||
self._label_dir_path = ""
|
||||
self._image_dir_path = ""
|
||||
|
||||
if label_list:
|
||||
self._label_id_map = {label: label_id
|
||||
for label_id, label in enumerate(label_list)}
|
||||
else:
|
||||
self._label_id_map = get_label_id_map(self._json_dir)
|
||||
self._label_list = [label for label in self._label_id_map.keys()]
|
||||
self._label_list = list(self._label_id_map.keys())
|
||||
|
||||
def _make_train_val_dir(self):
|
||||
self._label_dir_path = os.path.join(self._json_dir,
|
||||
|
@ -186,7 +187,10 @@ class Labelme2YOLO(object):
|
|||
os.makedirs(yolo_path)
|
||||
|
||||
def _train_test_split(self, folders, json_names, val_size, test_size):
|
||||
if len(folders) > 0 and 'train' in folders and 'val' in folders and 'test' in folders: # noqa: E501
|
||||
if (len(folders) > 0 and
|
||||
'train' in folders and
|
||||
'val' in folders and
|
||||
'test' in folders):
|
||||
train_folder = os.path.join(self._json_dir, 'train/')
|
||||
train_json_names = [train_sample_name + '.json'
|
||||
for train_sample_name in os.listdir(train_folder)
|
||||
|
@ -219,6 +223,7 @@ class Labelme2YOLO(object):
|
|||
return train_json_names, val_json_names, test_json_names
|
||||
|
||||
def convert(self, val_size, test_size):
|
||||
'''Convert labelme format to yolo format'''
|
||||
json_names = [file_name for file_name in os.listdir(self._json_dir)
|
||||
if os.path.isfile(os.path.join(self._json_dir, file_name)) and
|
||||
file_name.endswith('.json')]
|
||||
|
@ -233,23 +238,24 @@ class Labelme2YOLO(object):
|
|||
# also get image from labelme json file and save them under images folder
|
||||
for target_dir, json_names in zip(('train/', 'val/', 'test/'),
|
||||
(train_json_names, val_json_names, test_json_names)): # noqa: E501
|
||||
pool = Pool(NUM_THREADS)
|
||||
|
||||
for json_name in json_names:
|
||||
pool.apply_async(self.covert_json_to_text,
|
||||
args=(target_dir, json_name))
|
||||
pool.close()
|
||||
pool.join()
|
||||
with Pool(NUM_THREADS) as pool:
|
||||
for json_name in json_names:
|
||||
pool.apply_async(self.covert_json_to_text,
|
||||
args=(target_dir, json_name))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
print('Generating dataset.yaml file ...')
|
||||
self._save_dataset_yaml()
|
||||
|
||||
def covert_json_to_text(self, target_dir, json_name):
|
||||
"""Convert json file to yolo format text file and save them to files"""
|
||||
json_path = os.path.join(self._json_dir, json_name)
|
||||
json_data = json.load(open(json_path))
|
||||
with open(json_path, encoding="utf-8") as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
print('Converting %s for %s ...' %
|
||||
(json_name, target_dir.replace('/', '')))
|
||||
print(f"Converting {json_name} for {target_dir.replace('/', '')} ...")
|
||||
|
||||
img_path = save_yolo_image(json_data,
|
||||
json_path,
|
||||
|
|
Loading…
Reference in New Issue