近年来,随着大型语言模型(LLM)的发展,基于文本语义的图像检索技术取得了显著进步。这些模型通过理解复杂的自然语言描述,能够更准确地捕捉文本的语义,从而提高检索的精度和效率。
基于文本语义的图像检索是一种利用自然语言描述来搜索和获取相关图像的技术。实现这种检索的关键在于多模态学习技术,它能够处理和整合多种类型的数据,如文本和图像。通过将文本和图像数据嵌入到一个共享的表示空间,模型可以计算文本描述与图像之间的相似性,从而实现精确的检索。

相关资料:
https://huggingface.co/jinaai/jina-clip-v2
https://huggingface.co/jinaai/jina-clip-implementation
https://huggingface.co/datasets/jackyhate/text-to-image-2M
1. 图像存储
import glob
import os
import sqlite3

from PIL import Image
from tqdm import tqdm


def image_to_sqlite():
    """Register every JPEG under ``images/`` in the ``images.db`` SQLite database.

    Each file path is stored in the ``fname`` column of the ``images`` table;
    the integer primary key ``id`` assigned by SQLite is the identifier the
    later encoding/retrieval steps use to map a vector back to a file.

    The function prints the number of files found and the size of the first
    image, then inserts the paths while showing a progress bar. It returns
    ``None``. If no files match, it returns right after printing the count.
    """
    fnames = glob.glob('images/*.jpg')
    print('图像数量:', len(fnames))
    if not fnames:
        # Nothing to register; also avoids an IndexError on fnames[0] below.
        return

    # Use a context manager so the file handle of the probed image is released.
    with Image.open(fnames[0]) as first_image:
        print('图像尺寸:', first_image.size)

    # Connect to the database (the file is created if it does not exist).
    conn = sqlite3.connect('images.db')
    try:
        cursor = conn.cursor()
        # UNIQUE on fname is what gives "INSERT OR IGNORE" something to ignore:
        # without it, re-running the script would duplicate every row.
        # NOTE: an already existing table keeps its old schema (IF NOT EXISTS).
        cursor.execute(
            'CREATE TABLE IF NOT EXISTS images '
            '(id INTEGER PRIMARY KEY, fname TEXT UNIQUE)'
        )

        # Bind the path as a parameter. The previous statement was a plain
        # string containing "{fname}" (no f-prefix), so the literal text
        # "{fname}" was stored for every image instead of its path.
        insert_sql = 'INSERT OR IGNORE INTO images (fname) VALUES (?)'
        for fname in tqdm(fnames, '图像入库'):
            cursor.execute(insert_sql, (fname,))

        # One commit for the whole batch instead of one transaction per row.
        conn.commit()
    finally:
        conn.close()


if __name__ == '__main__':
    image_to_sqlite()
图像数量: 8128 图像尺寸: (512, 512) 图像入库: 100%|████████████████████████████| 8128/8128 [00:45<00:00, 178.89it/s]
2. 图像编码
import warnings
warnings.filterwarnings('ignore')

import sqlite3

import faiss
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel

# Embedding size requested from the encoder; the FAISS index must use the same value.
EMBEDDING_DIM = 512


def image_to_faiss():
    """Encode every image listed in ``images.db`` and write ``images.faiss``.

    Rows ``(id, fname)`` are read from the ``images`` table, the file names
    are passed in batches of 64 to the model's ``encode_image``, and the
    resulting vectors are added to an inner-product FAISS index keyed by the
    database ``id``. The finished index is written to ``images.faiss``.
    Returns ``None``.
    """
    conn = sqlite3.connect('images.db')
    try:
        # Name the columns explicitly so the (id, fname) unpacking below does
        # not depend on the table's column order.
        rows = conn.execute('SELECT id, fname FROM images').fetchall()
    finally:
        conn.close()

    def collate_fn(batch_rows):
        """Split a batch of ``(id, fname)`` rows into ``(fnames, ids)`` lists."""
        fnames = [fname for _, fname in batch_rows]
        image_ids = [image_id for image_id, _ in batch_rows]
        return fnames, image_ids

    estimator = AutoModel.from_pretrained('jina-clip-v2', trust_remote_code=True).cuda()
    dataloader = DataLoader(rows, batch_size=64, collate_fn=collate_fn)

    # IndexIDMap lets us store the database ids alongside the vectors, so a
    # search result can be looked up directly in the images table.
    flat_index = faiss.IndexFlatIP(EMBEDDING_DIM)
    index = faiss.IndexIDMap(flat_index)

    with torch.no_grad():
        for fnames, image_ids in tqdm(dataloader, '图像编码'):
            vectors = estimator.encode_image(fnames, truncate_dim=EMBEDDING_DIM)
            # FAISS ids are 64-bit integers; pass an explicit int64 array
            # rather than relying on the wrapper to convert a Python list.
            index.add_with_ids(vectors, np.asarray(image_ids, dtype='int64'))

    faiss.write_index(index, 'images.faiss')


if __name__ == '__main__':
    image_to_faiss()
图像编码: 100%|███████████████████████████████| 127/127 [08:21<00:00, 3.95s/it]
3. 图像检索
import warnings
warnings.filterwarnings('ignore')

import sqlite3

import faiss
import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import AutoModel


def demo(top_k=2):
    """Interactive text-to-image search over ``images.faiss`` / ``images.db``.

    Repeatedly prompts for a text query, encodes it with the model's
    ``encode_text``, searches the FAISS index for the ``top_k`` nearest
    vectors, resolves the returned ids to file names through the ``images``
    table, and shows the matching images side by side.

    Args:
        top_k: Number of images to retrieve and display per query
            (default 2, the value that was previously hard-coded).

    The loop ends on end-of-input or Ctrl-C at the prompt; the database
    connection is closed on the way out. Returns ``None``.
    """
    index = faiss.read_index('images.faiss')
    estimator = AutoModel.from_pretrained('jina-clip-v2', trust_remote_code=True).cuda()

    conn = sqlite3.connect('images.db')
    try:
        cursor = conn.cursor()
        while True:
            try:
                query = input('请输入内容:')
            except (EOFError, KeyboardInterrupt):
                # Leave the otherwise endless loop cleanly.
                break

            with torch.no_grad():
                vector = estimator.encode_text(query, task='retrieval.query', truncate_dim=512)
            # index.search expects a 2-D (n_queries, dim) array.
            vector = vector.reshape(1, -1)

            _, image_ids = index.search(vector, k=top_k)
            # Take row 0 (the single query) instead of squeeze(), which would
            # collapse to a scalar when top_k == 1.
            for position, image_id in enumerate(image_ids[0].tolist(), start=1):
                if image_id < 0:
                    # FAISS pads with -1 when fewer than top_k results exist.
                    continue
                cursor.execute('SELECT fname FROM images WHERE id=?', (image_id,))
                row = cursor.fetchone()
                if row is None:
                    # Index and database are out of sync for this id; skip it
                    # rather than crash on unpacking None.
                    continue
                (image_name,) = row

                plt.subplot(1, top_k, position)
                with Image.open(image_name) as image:
                    plt.imshow(image)
                plt.axis('off')
            plt.show()
    finally:
        conn.close()


if __name__ == '__main__':
    demo()