基于文本语义的图像检索应用

近年来,随着大型语言模型(LLM)的发展,基于文本语义的图像检索技术取得了显著进步。这些模型通过理解复杂的自然语言描述,能够更准确地捕捉文本的语义,从而提高检索的精度和效率。

基于文本语义的图像检索是一种利用自然语言描述来搜索和获取相关图像的技术。实现这种检索的关键在于多模态学习技术,它能够处理和整合多种类型的数据,如文本和图像。通过将文本和图像数据嵌入到一个共享的表示空间,模型可以计算文本描述与图像之间的相似性,从而实现精确的检索。

相关资料:

https://huggingface.co/jinaai/jina-clip-v2
https://huggingface.co/jinaai/jina-clip-implementation
https://huggingface.co/datasets/jackyhate/text-to-image-2M

1. 图像存储

import glob
from PIL import Image
import sqlite3
from tqdm import tqdm
import os


def image_to_sqlite():
    fnames = glob.glob('images/*.jpg')
    print('图像数量:', len(fnames))
    print('图像尺寸:', Image.open(fnames[0]).size)
    # 连接到数据库(如果不存在则创建)
    conn = sqlite3.connect('images.db')
    cursor = conn.cursor()
    create_sql = 'CREATE TABLE IF NOT EXISTS images (id INTEGER PRIMARY KEY, fname TEXT)'
    cursor.execute(create_sql)

    progress = tqdm(range(len(fnames)), '图像入库')
    for fname in fnames:
        insert_sql = 'INSERT OR IGNORE INTO images (fname) VALUES ("{fname}")'
        cursor.execute(insert_sql)
        conn.commit()
        progress.update()
    progress.close()
    conn.close()


if __name__ == '__main__':
    image_to_sqlite()
图像数量: 8128
图像尺寸: (512, 512)
图像入库: 100%|████████████████████████████| 8128/8128 [00:45<00:00, 178.89it/s]

2. 图像编码

import warnings
warnings.filterwarnings('ignore')
import sqlite3
import faiss
import torch
from transformers import AutoModel
from torch.utils.data import DataLoader
from tqdm import tqdm


def image_to_faiss():
    conn = sqlite3.connect('images.db')
    cursor = conn.cursor()
    select_sql = 'SELECT * FROM images'
    cursor.execute(select_sql)
    images = cursor.fetchall()
    conn.close()

    def collate_fn(batch_images):
        images, ids = [], []
        for id, image in batch_images:
            images.append(image)
            ids.append(id)
        return images, ids

    estimator = AutoModel.from_pretrained('jina-clip-v2', trust_remote_code=True).cuda()
    dataloader = DataLoader(images, batch_size=64, collate_fn=collate_fn)
    progress = tqdm(range(len(dataloader)), '图像编码')
    index = faiss.IndexFlatIP(512)
    index = faiss.IndexIDMap(index)

    for images, ids in dataloader:
        with torch.no_grad():
            vectors = estimator.encode_image(images, truncate_dim=512)
        index.add_with_ids(vectors, ids)
        progress.update()
    progress.close()
    faiss.write_index(index, 'images.faiss')


if __name__ == '__main__':
    image_to_faiss()
图像编码: 100%|███████████████████████████████| 127/127 [08:21<00:00,  3.95s/it]

3. 图像检索

import warnings
warnings.filterwarnings('ignore')
import torch
import faiss
from transformers import AutoModel
import sqlite3
import matplotlib.pyplot as plt
from PIL import Image


def demo():
    index = faiss.read_index('images.faiss')
    conn = sqlite3.connect('images.db')
    cursor = conn.cursor()
    estimator = AutoModel.from_pretrained('jina-clip-v2', trust_remote_code=True).cuda()

    while True:
        query = input('请输入内容:')
        with torch.no_grad():
            vector = estimator.encode_text(query, task='retrieval.query', truncate_dim=512)
        vector = vector.reshape(1, -1)
        image_sim, image_ids = index.search(vector, k=2)
        image_ids = image_ids.squeeze().tolist()

        image_names = []
        for idx, image_id in enumerate(image_ids, start=1):
            query = f'SELECT fname FROM images WHERE id={image_id}'
            cursor.execute(query)
            (image_name,) = cursor.fetchone()
            plt.subplot(1, 2, idx)
            plt.imshow(Image.open(image_name))
            plt.axis('off')
        plt.show()


if __name__ == '__main__':
    demo()

未经允许不得转载:一亩三分地 » 基于文本语义的图像检索应用
评论 (0)

3 + 4 =