Scene Graph Generation Benchmark: the Visual Genome data split (train/test/val)

Preface

Many people working on scene graph generation discover, when preparing to evaluate, that Visual Genome does not ship an official train/test/val split.

However, we can consult the Scene-Graph-Benchmark.pytorch code, locate how it splits VG150, and adopt that split, which is the one this field conventionally uses, to divide VG150 for our own experiments.

We open the Scene-Graph-Benchmark.pytorch code and set a breakpoint in this file:
Scene-Graph-Benchmark.pytorch/maskrcnn_benchmark/data/build.py

def build_dataset(cfg, dataset_list, transforms, dataset_catalog, is_train=True):
    """
    Arguments:
        dataset_list (list[str]): Contains the names of the datasets, i.e.,
            coco_2014_train, coco_2014_val, etc
        transforms (callable): transforms to apply to each (image, target) sample
        dataset_catalog (DatasetCatalog): contains the information on how to
            construct a dataset.
        is_train (bool): whether to set up the dataset for training or testing
    """
    if not isinstance(dataset_list, (list, tuple)):
        raise RuntimeError(
            "dataset_list should be a list of strings, got {}".format(dataset_list)
        )
    datasets = []
    for dataset_name in dataset_list:
        data = dataset_catalog.get(dataset_name, cfg)
        factory = getattr(D, data["factory"])
        args = data["args"]
        # for COCODataset, we want to remove images without annotations
        # during training
        if data["factory"] == "COCODataset":
            args["remove_images_without_annotations"] = is_train
        if data["factory"] == "PascalVOCDataset":
            args["use_difficult"] = not is_train
        args["transforms"] = transforms

        # Removed because it is not part of the original repo (the factory can't deal with additional parameters...)
        if "capgraphs_file" in args.keys():
            del args["capgraphs_file"]

        # make dataset from factory
        dataset = factory(**args)
        datasets.append(dataset)

    # for testing, return a list of datasets
    if not is_train:
        return datasets

    # for training, concatenate all datasets into a single one
    dataset = datasets[0]
    if len(datasets) > 1:
        dataset = D.ConcatDataset(datasets)

    return [dataset]

Roughly, this code does the following: for the current dataset_name (train, test, or val), it uses the arguments the project has already packaged to read the corresponding files and build a dataset object. So the two lines we mainly set breakpoints on are:

dataset = factory(**args)
datasets.append(dataset)

Then, depending on dataset_name, we dump each of the three datasets to disk with the pickle module, i.e., add code like this right after the breakpoint lines:

import pickle

if dataset_name == 'VG_stanford_filtered_with_attribute_train':
    with open('data_train.pkl', 'wb') as f:
        pickle.dump(dataset, f)
if dataset_name == 'VG_stanford_filtered_with_attribute_test':
    with open('data_test.pkl', 'wb') as f:
        pickle.dump(dataset, f)
if dataset_name == 'VG_stanford_filtered_with_attribute_val':
    with open('data_val.pkl', 'wb') as f:
        pickle.dump(dataset, f)

Running the project once in each mode (e.g., via tools/relation_train_net.py for the train and val datasets, and tools/relation_test_net.py for the test dataset) produces the three .pkl files.
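To confirm a dump is usable, here is a minimal sanity check (a sketch; it must run inside the same Scene-Graph-Benchmark.pytorch environment, since unpickling needs the repo's dataset class):

import pickle

# Load a dumped split back and inspect a few fields
with open('data_train.pkl', 'rb') as f:
    train_set = pickle.load(f)

print(type(train_set))           # the repo's dataset class
print(len(train_set.filenames))  # number of images in this split
print(train_set.filenames[:3])   # first few image paths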
However, because their format is tied to maskrcnn-benchmark's dataset classes, these three files can only be unpickled inside the Scene-Graph-Benchmark.pytorch project; anywhere else, pickle fails as soon as it tries to import those classes. So I converted the three pkl files as follows:

import json
import pickle

file_paths = {
    'train': '/Scene-Graph-Benchmark.pytorch/data_train.pkl',
    'test': '/Scene-Graph-Benchmark.pytorch/data_test.pkl',
    'val': '/Scene-Graph-Benchmark.pytorch/data_val.pkl'
}

# Read a pkl file and convert the dataset object into a plain dictionary
def read_pkl_file(file_path):
    with open(file_path, 'rb') as f:
        data = pickle.load(f)

    # Convert every field into a JSON-serializable type
    data_dict = {
        'categories': data.categories,
        'filenames': data.filenames,
        'gt_boxes': [box.tolist() for box in data.gt_boxes],  # numpy arrays -> lists
        'gt_classes': [cls.tolist() for cls in data.gt_classes],  # numpy arrays -> lists
        'img_info': data.img_info,
        'ind_to_attributes': data.ind_to_attributes,
        'ind_to_classes': data.ind_to_classes,
        'ind_to_predicates': data.ind_to_predicates,
        'relationships': [rel.tolist() for rel in data.relationships],  # numpy arrays -> lists
        'split': data.split
    }

    return data_dict

# Save a dictionary to a JSON file
def save_to_json(data_dict, output_path):
    with open(output_path, 'w') as json_file:
        json.dump(data_dict, json_file, indent=4)

# Process each split and save it as JSON
for split, file_path in file_paths.items():
    data_dict = read_pkl_file(file_path)
    output_path = f'/Scene-Graph-Benchmark.pytorch/data_{split}.json'
    save_to_json(data_dict, output_path)
    print(f'Saved {split} data to {output_path}')

In the end we get three JSON files, which can be loaded locally without the repo environment. Loading them looks like this:

import json
import os
import sys

file1_path = './data_train.json'
file2_path = './data_test.json'
file3_path = './data_val.json'

with open(file1_path, 'r', encoding='utf-8') as file:
    data1 = json.load(file)

with open(file2_path, 'r', encoding='utf-8') as file:
    data2 = json.load(file)

with open(file3_path, 'r', encoding='utf-8') as file:
    data3 = json.load(file)

print(data1)
print(data2)
print(data3)
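Printing the full dictionaries produces a wall of output; a compact summary is often more practical. Here is a small sketch using only the keys shown above:

for name, d in [('train', data1), ('test', data2), ('val', data3)]:
    print(name,
          'images:', len(d['filenames']),
          'classes:', len(d['ind_to_classes']),        # 151 = background + 150 nouns
          'predicates:', len(d['ind_to_predicates']))  # 51  = background + 50 predicates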

The other fields are self-explanatory; just note that gt_classes gives the class index of each object in gt_boxes.
ind_to_classes holds the background class plus the 150 noun classes, and
ind_to_predicates holds the background class plus the 50 predicate classes,
matching VG150's 150 nouns and 50 predicates exactly.
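For example, a minimal sketch that decodes the first val image's box labels into class names (assuming the data_val.json produced above):

import json

with open('./data_val.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# gt_classes[i] holds one class index per box in gt_boxes[i]
labels = [data['ind_to_classes'][c] for c in data['gt_classes'][0]]
print(data['filenames'][0], labels)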

Now you can happily use this SGG split of the VG150 dataset.

Link: https://pan.baidu.com/s/1QGaNxPoWi9BIckYafoQrvg
Extraction code: yw1y

Update:

I also used the following code to convert the numeric labels in these files into the corresponding English triplets.

import json
import os

def triple_translation_json(file_path, output_file):
    # Read the split's JSON data from the local file
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Index -> name lookup tables (identical for every image)
    predicates = data['ind_to_predicates']
    ind_to_classes = data['ind_to_classes']

    # Replace an index triplet [subject_class, predicate, object_class]
    # with the corresponding English names
    def replace_with_names(triplets):
        replaced_triplets = []
        for triplet in triplets:
            replaced_triplet = [
                ind_to_classes[triplet[0]],
                predicates[triplet[1]],
                ind_to_classes[triplet[2]]
            ]
            replaced_triplets.append(replaced_triplet)
        return replaced_triplets

    output_data = []

    # For each image
    for img_id, image_name_long in enumerate(data['filenames']):
        img_name = os.path.basename(image_name_long)
        # img_id and img_name are what we need below.

        # Ground-truth relations for this image, each stored as
        # [subject_box_index, object_box_index, predicate_index],
        # e.g. gt_unrevised_triplets[0] = [7, 4, 31]
        gt_unrevised_triplets = data['relationships'][img_id]

        # Map box indices to class indices: [subject_class, predicate, object_class]
        gt_temp_triplets = []
        for triplet in gt_unrevised_triplets:
            gt_triplet = [data['gt_classes'][img_id][triplet[0]],
                          triplet[2],
                          data['gt_classes'][img_id][triplet[1]]]
            gt_temp_triplets.append(gt_triplet)

        # Convert the index triplets to English names
        gt_triplets = replace_with_names(gt_temp_triplets)

        # Append this image's result to the output list
        output_data.append({
            'image_id': img_id,
            'image_path': img_name,
            'gt_triplets': gt_triplets,
        })

    # Save the result to a new JSON file
    with open(output_file, 'w', encoding='utf-8') as file:
        json.dump(output_data, file, ensure_ascii=False, indent=4)

if __name__ == "__main__":
    # Translate all three splits into English triplet files
    for split in ['train', 'test', 'val']:
        file_path = f'./data_{split}.json'
        output_file = f'./data_{split}_triplets.json'
        triple_translation_json(file_path, output_file)
        print(f'Saved {split} triplets to {output_file}')

In the output file, each image becomes one entry with its image_id, image_path, and gt_triplets.
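For example (a hand-written illustration: the object and predicate names come from ind_to_classes and ind_to_predicates, and the file name here is hypothetical):

[
    {
        "image_id": 0,
        "image_path": "498334.jpg",
        "gt_triplets": [
            ["man", "wearing", "shirt"],
            ["tree", "behind", "man"]
        ]
    }
]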
The three generated triplet files have also been uploaded:
Link: https://pan.baidu.com/s/1LQdsDOfRKYv_QwZ1WKofvA
Extraction code: tp60

Wishing everyone smooth and happy research!
