修改前
# get masks
# Load every object's mask and caption, in the order of meta['object_list'].
# `index` is a mapping with 'mask path' and 'caption' keys, not an integer.
for index in meta['object_list']:
    mask_path = prefix + index['mask path']
    # Binary mode: json.load accepts bytes and auto-detects UTF-8/16/32
    # (including a UTF-8 BOM), so decoding does not depend on the locale.
    with open(mask_path, "rb") as f:
        mask = json.load(f)
    # NOTE(review): presumably the file holds a list of masks and only the
    # first is wanted — confirm against the data producer.
    mask = np.array(mask[0])
    masks.append(mask)
    text_descriptions.append(index['caption'])
修改后
import json
import numpy as np
from multiprocessing import Pool
def load_mask(mask_path):
    """Read a JSON mask file and return its first element as a NumPy array.

    Args:
        mask_path: path to a JSON file whose top-level value is a list.

    Returns:
        ``np.array(data[0])`` where ``data`` is the decoded JSON value.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        IndexError: if the top-level list is empty.
    """
    # Open in binary mode: json.load accepts bytes and auto-detects
    # UTF-8/16/32 (including a UTF-8 BOM), so decoding no longer depends on
    # the platform's default text encoding.
    with open(mask_path, "rb") as f:
        data = json.load(f)
    # NOTE(review): presumably the file stores a list of masks and only the
    # first one is used here — confirm against the data producer.
    return np.array(data[0])
def load_data(index, path_prefix=None):
    """Load one object's mask and caption.

    Args:
        index: one entry of ``meta['object_list']``; it must provide the keys
            ``'mask path'`` and ``'caption'``.
        path_prefix: string prepended to ``index['mask path']``. Defaults to
            the module-level ``prefix`` (the original behaviour). Passing it
            explicitly, e.g. ``functools.partial(load_data, path_prefix=prefix)``,
            sends the value to Pool workers with each task. That way workers
            do not rely on ``prefix`` existing after a spawn-based re-import
            of the main module.

    Returns:
        A ``(mask, caption)`` tuple, where ``mask`` comes from ``load_mask``.
    """
    if path_prefix is None:
        path_prefix = prefix
    # Plain concatenation (not os.path.join) so the path is built exactly as before.
    mask_path = path_prefix + index['mask path']
    mask = load_mask(mask_path)
    text_description = index['caption']
    return mask, text_description
# Collected results, in the same order as meta['object_list'].
masks = []
text_descriptions = []

# Required by multiprocessing. Under the "spawn" start method (the default on
# Windows and macOS), every worker re-imports this module, and without the
# guard each worker would try to create a pool of its own.
if __name__ == "__main__":
    with Pool() as pool:
        # map() blocks until all entries are processed and preserves input order.
        results = pool.map(load_data, meta['object_list'])
        masks.extend(m for m, _ in results)
        text_descriptions.extend(caption for _, caption in results)
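上面的多进程写法可能出问题:在 spawn 启动方式下(Windows、macOS 默认),子进程会重新导入主模块,因此 `prefix`、`meta` 必须定义在模块顶层而不是 `if __name__ == "__main__":` 内部;在 Jupyter 等交互式环境中,Pool 也可能找不到交互式定义的函数。这种情况下也可以改用线程池,这样写: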
import json
import numpy as np
from concurrent.futures import ThreadPoolExecutor
def load_mask(mask_path):
    """Read a JSON mask file and return its first element as a NumPy array.

    Args:
        mask_path: path to a JSON file whose top-level value is a list.

    Returns:
        ``np.array(data[0])`` where ``data`` is the decoded JSON value.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        IndexError: if the top-level list is empty.
    """
    # Open in binary mode: json.load accepts bytes and auto-detects
    # UTF-8/16/32 (including a UTF-8 BOM), so decoding no longer depends on
    # the platform's default text encoding.
    with open(mask_path, "rb") as f:
        data = json.load(f)
    # NOTE(review): presumably the file stores a list of masks and only the
    # first one is used here — confirm against the data producer.
    return np.array(data[0])
def load_data(index):
    """Return ``(mask, caption)`` for one entry of ``meta['object_list']``.

    The mask path is the module-level ``prefix`` followed by the entry's
    ``'mask path'`` value; the caption is the entry's ``'caption'`` value.
    """
    mask = load_mask(prefix + index['mask path'])
    return mask, index['caption']
# Thread-based alternative. Worker threads share this module's globals
# (prefix, meta), so nothing is pickled and no __main__ guard is needed.
# Threads mainly help by overlapping file reads; on a standard CPython build,
# the JSON parsing itself does not run in parallel across threads.
masks = []
text_descriptions = []

with ThreadPoolExecutor() as executor:
    # executor.map yields results in input order. An exception raised inside
    # load_data is re-raised here when its result is reached.
    for mask, text_description in executor.map(load_data, meta['object_list']):
        masks.append(mask)
        text_descriptions.append(text_description)
本文由 Yonghui Wang 创作,采用知识共享署名 4.0 国际许可协议进行许可。
本站文章除注明转载/出处外,均为本站原创或翻译,转载前请务必署名
最后编辑时间为:Dec 19, 2024 12:06 pm