Keyword Recall in Retrieval Systems

When building a retrieval system, we can use keywords to recall candidate results. Below are two simple approaches:

  1. Recall based on an Inverted Index
  2. Recall based on TF-IDF

1. Recall Based on an Inverted Index

The idea is straightforward. Suppose we have 10,000 documents: we tokenize each one with the jieba tokenizer, remove stopwords, and build a mapping from each word to the documents that contain it. When a query sentence arrives, we extract its words, use them to quickly look up the documents containing those keywords, rank the candidates by how many of the query keywords they contain, and return the top results.
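
For intuition, here is a toy sketch of the data structure over two hand-made documents (the document ids and tokens are invented purely for illustration; the full implementation follows below):

# Two already-tokenized documents, keyed by document id
docs = {
    1: ['宝宝', '发烧', '咳嗽'],
    2: ['宝宝', '腹泻'],
}

# Map each word to the list of document ids that contain it
inverted_index = {}
for doc_id, words in docs.items():
    for word in set(words):
        inverted_index.setdefault(word, []).append(doc_id)

# inverted_index == {'宝宝': [1, 2], '发烧': [1], '咳嗽': [1], '腹泻': [2]}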

Implementation code:

import jieba
jieba.setLogLevel(0)
import random
import pandas as pd
import pickle
from collections import Counter
from data_select import select_questions
from data_select import select_and_show_question


# Build an inverted index from word to question ids
def build_inverted_index():

    questions = select_questions()
    inverted_index = {}
    stopwords = [word.strip() for word in open('file/stopwords.txt')]
    for qid, question in questions:
        words = [word for word in jieba.lcut(question) if word not in stopwords]
        if len(words) == 0:
            print('Tokenization produced no valid words for question:', question)
            continue
        # Build the postings: deduplicate words so each question id is
        # recorded at most once per keyword
        for word in set(words):
            inverted_index.setdefault(word, []).append(qid)

    pickle.dump(inverted_index, open('finish/keyword/inverted_index/inverted_index.pkl', 'wb'))


# Return candidate question ids that share keywords with the query, via the inverted index
def generate_candidate(query, inversed_index, topK):

    # Tokenize the query and filter out stopwords
    query = jieba.lcut(query)
    stopwords = [word.strip() for word in open('file/stopwords.txt')]
    query_words = [word for word in query if word not in stopwords]
    print('Query tokens:', query_words)

    # Candidate question ids that contain at least one query keyword
    candidate_questions = []
    # Collect the posting list for each query keyword
    for word in query_words:
        candidate_questions.extend(inversed_index.get(word, []))
    # Keep the top K questions that contain the most query keywords
    candidate_questions = Counter(candidate_questions).most_common(topK)
    candidate_questions = [question for question, freq in candidate_questions]

    return candidate_questions


def test():

    # Load the inverted index
    inverted_index = pickle.load(open('finish/keyword/inverted_index/inverted_index.pkl', 'rb'))
    query_string = '宝宝的妈妈嗓子疼有点发烧孩子就是发烧'
    print('Input question:', query_string)
    ids = generate_candidate(query_string, inverted_index, topK=10)
    print(ids)
    print('-' * 50)
    select_and_show_question(ids)


if __name__ == '__main__':
    build_inverted_index()
    test()
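
For reference, the ranking step in generate_candidate relies on collections.Counter.most_common, which orders candidate question ids by how many of the query's keywords they matched. A quick illustration with made-up ids:

from collections import Counter

# Question ids gathered from each query keyword's posting list
candidate_ids = [12, 7, 12, 33, 7, 12]

# most_common sorts by match count, highest first
print(Counter(candidate_ids).most_common(2))   # [(12, 3), (7, 2)]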

2. Recall Based on TF-IDF

This approach is also straightforward. We first train a TfidfVectorizer model on the corpus, then convert every question into a TF-IDF vector and store the vectors in a vector store such as faiss or milvus. When a new query sentence arrives, we convert it into a TF-IDF vector with the trained model, compare it against the stored vectors by cosine similarity, and return the top K results as recall candidates.
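
One detail worth calling out: the query code below searches a faiss IndexFlatIP (inner-product) index rather than an explicit cosine index. Since sklearn's TfidfVectorizer L2-normalizes its output rows by default (norm='l2'), the inner product of two TF-IDF vectors equals their cosine similarity. Here is a minimal sketch of that equivalence (the toy sentences are invented for illustration only):

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Toy corpus, already whitespace-tokenized as cut_word() would produce
corpus = ['宝宝 发烧 咳嗽', '宝宝 腹泻 拉肚子', '孕妇 产检 项目']

vectorizer = TfidfVectorizer()                    # norm='l2' by default
matrix = vectorizer.fit_transform(corpus).toarray()
query = vectorizer.transform(['宝宝 有点 发烧']).toarray()

# Inner product on L2-normalized rows equals cosine similarity
print(np.dot(matrix, query.T).ravel())
print(cosine_similarity(matrix, query).ravel())   # same values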

Training code:

import jieba
jieba.setLogLevel(0)
import jieba.analyse as analyse  # used only by the commented-out keyword-extraction alternative below
import jieba.posseg as psg       # used only by the commented-out POS-filtered alternative below
import pickle
from sklearn.feature_extraction.text import TfidfVectorizer
from data_select import select_all_questions


def is_chinese_word(word):

    # Keep only tokens made up entirely of Chinese characters
    for char in word:
        if not ('\u4e00' <= char <= '\u9fff'):
            return False

    return True


def cut_word(sentence):

    # p = ['n', 'nr', 'ns', 'nt', 'nl', 'nz', 'nsf', 's'] + ['v', 'vd', 'vn', 'vx'] + ['a', 'ad', 'al', 'an']

    # Coarse-grained segmentation filtered by part of speech
    # words_with_pos = psg.cut(sentence)
    # question_words = [word for word, pos in words_with_pos if pos in p]

    # Keyword extraction
    # question_words = analyse.tfidf(sentence, allowPOS=p, topK=30)

    # Search-engine mode: split out as many candidate words as possible
    question_words = jieba.lcut_for_search(sentence)
    question_words = [word for word in question_words if is_chinese_word(word)]

    # Synonym expansion (requires the synonyms package)
    # words = analyse.textrank(sentence, allowPOS=p)
    # print('Synonym expansion:', [synonyms.nearby(word) for word in words])

    return ' '.join(question_words)


def train_tfidf():

    questions = select_all_questions()
    questions_words = [cut_word(question) for qid, question in questions]
    max_features = 81920
    stopwords = [word.strip() for word in open('file/stopwords.txt')]
    # Note: the default token_pattern of TfidfVectorizer drops single-character tokens
    estimator = TfidfVectorizer(max_features=max_features, stop_words=stopwords, ngram_range=(1, 2))
    estimator.fit(questions_words)

    print('Number of features:', len(estimator.get_feature_names_out()))
    print('Sample features:', estimator.get_feature_names_out()[:50])

    pickle.dump(estimator, open('finish/keyword/tfidf/tfidf.pkl', 'wb'))


if __name__ == '__main__':
    train_tfidf()

Query code:

import faiss
import pickle
import numpy as np
import jieba
jieba.setLogLevel(0)
from data_select import select_and_show_question
from data_select import select_and_show_solution
from data_select import select_questions
from keyword_tfidf_train import cut_word


def generate_tfidf_to_faiss():

    estimator = pickle.load(open('finish/keyword/tfidf/tfidf.pkl', 'rb'))
    questions = select_questions()
    questions_words = [(qid, cut_word(question)) for qid, question in questions]

    # The index dimension must match the vectorizer's actual vocabulary size
    dimension = len(estimator.get_feature_names_out())
    database = faiss.IndexIDMap(faiss.IndexFlatIP(dimension))

    write_number = 0
    for qid, question in questions_words:
        # Some questions yield an empty token list; skip them
        if not question.strip():
            continue
        vector = estimator.transform([question]).toarray().astype('float32')
        database.add_with_ids(vector, np.array([qid], dtype='int64'))
        write_number += 1

    print('Number of TF-IDF vectors written:', write_number)
    faiss.write_index(database, 'finish/keyword/tfidf/tfidf.faiss')


def test():

    estimator = pickle.load(open('finish/keyword/tfidf/tfidf.pkl', 'rb'))
    database = faiss.read_index('finish/keyword/tfidf/tfidf.faiss')

    # Input question
    # input_question = '宝宝的妈妈嗓子疼有点发烧孩子就是发烧'
    # input_question = '怀孕时乳房会有刺痛感吗'
    # input_question = '小孩发烧,吃点什么药啊?'
    # input_question = '染头发影响宝宝吃奶吗?'
    query_string = '吃点啥药能降血压啊?'
    print('Input question:', query_string)
    query_words = [cut_word(query_string)]
    print('Query tokens:', query_words)

    query_vector = estimator.transform(query_words).toarray().astype('float32')
    distances, ids = database.search(query_vector, 10)

    print(ids[0])
    print(distances[0].tolist())
    select_and_show_question(ids[0])
    print('-' * 100)
    select_and_show_solution(ids[0])


if __name__ == '__main__':
    # Build the faiss index first by running generate_tfidf_to_faiss() once
    # generate_tfidf_to_faiss()
    test()