Chroma Hello World

Chroma is the open-source embedding database. Chroma makes it easy to build LLM apps by making knowledge, facts, and skills pluggable for LLMs.

https://github.com/chroma-core/chroma
https://docs.trychroma.com

pip install chromadb

1. In-Memory Client

使用 Chroma 的步骤如下:

  1. 初始化 Client 对象;
  2. 在 Client 对象中创建一个或者多个 Collection 容器用来存储 Embeddings、Documents、Metadatas 等信息,每一个 Collection 需要指定一个唯一的名字;
  3. 使用 Collection 的 query 方法实现从 Collection 中查找与 texts 或者 Embeddings 近似的项。query 函数可以根据需要指定返回的结果,也可以通过元数据或者文本内容筛选、过滤结果。
import random
import chromadb
import numpy as np

def test():

    # 初始化客户端
    chroma_client = chromadb.Client()

    # 创建集合
    collection = chroma_client.create_collection(name="collection_demo01")

    # 插入数据
    embeddings = np.random.random((100, 128)).tolist()
    documents = [f'第{index + 1}文本内容' for index in range(100)]
    metadatas = [{'source': ['book', 'other'][random.randint(0, 1)]} for _ in range(100)]
    ids = [f'id-{id}' for id in range(100)]
    collection.add(embeddings=embeddings,
                   documents=documents,
                   metadatas=metadatas, ids=ids)

    # 查询数据
    query_embeddings = np.random.random((1, 128)).tolist()
    # where 参数可以根据 metadatas 过滤查询结果
    results = collection.query(query_embeddings=query_embeddings,
                               n_results=3,
                               include=['documents', 'metadatas'],
                               where={'source': 'book'})
    print(results)


if __name__ == '__main__':
    test()
{'ids': [['id-21', 'id-22', 'id-25']], 'distances': None, 'metadatas': [[{'source': 'book'}, {'source': 'book'}, {'source': 'book'}]], 'embeddings': None, 'documents': [['第22文本内容', '第23文本内容', '第26文本内容']]}

下面了解下 Collection 对象的相关方法:

import random
import chromadb
import numpy as np

def test():

    # 初始化客户端
    client = chromadb.Client()
    # 创建集合
    collection1 = client.create_collection(name="collection_demo01")
    collection2 = client.create_collection(name="collection_demo02")
    # 查看集合
    print(client.list_collections())
    # 删除集合
    client.delete_collection(name='collection_demo01')
    # 集合插入数据
    embeddings = np.random.random((100, 128)).tolist()
    documents = [f'第{index + 1}文本内容' for index in range(100)]
    collection2.add(embeddings=embeddings, documents=documents, ids=[str(item+1) for item in range(100)])
    # 集合元素数量
    print(collection2.count())
    # 获得指定ID数据
    print(collection2.get(ids=['2', '4'], include=['documents']))


if __name__ == '__main__':
    test()

Collection 的 Query 方法默认是 L2 计算相似度,我们也可以修改其默认的计算方式(ip、cosine):

import chromadb
import numpy as np

def test():

    np.random.seed(0)
    client = chromadb.Client()
    # "l2", "ip, "or "cosine"
    collection = client.create_collection(name="demo", metadata={'hnsw:space': 'cosine'})

    embeddings = np.random.random((200, 3)).tolist()
    ids = [str(item + 1) for item in range(200)]
    collection.add(embeddings=embeddings, ids=ids)

    print(collection.query(query_embeddings=np.random.random((1, 3)).tolist(), n_results=2))

if __name__ == '__main__':
    test()

2. Persistent Client

PersistentClient 可以自动的存储和加载 Collection,只需在初始化对象时指定持久化目录。

import logging
import random
import chromadb
chromadb.logger.setLevel(logging.ERROR)
import numpy as np

def test():

    # 初始化客户端
    client = chromadb.PersistentClient(path='collections')

    # 创建集合
    collection1 = client.get_or_create_collection(name="collection_demo01")
    collection2 = client.get_or_create_collection(name="collection_demo02")

    # 集合插入数据
    embeddings = np.random.random((200, 3))
    documents = [f'第{index + 1}文本内容' for index in range(200)]
    ids = [str(item+1) for item in range(200)]

    # add 函数不会重复插入 id 相同的元素
    collection1.add(embeddings=embeddings[:100, :].tolist(), documents=documents[:100], ids=ids[:100])
    collection2.add(embeddings=embeddings[100:, :].tolist(), documents=documents[100:], ids=ids[100:])

    # 打印集合元素
    print(collection1.peek(2))
    print(collection2.peek(2))


if __name__ == '__main__':
    test()

3. Http Client

启动并设置数据存储目录:

chroma run --path server_collections

通过下面的代码连接 chromadb 服务器:

import chromadb
import numpy as np

def test():

    # 连接 chroma 服务器
    client = chromadb.HttpClient(host='localhost', port=8000)

    # 创建集合
    collection = client.get_or_create_collection(name="collection_demo")

    # 插入数据
    embeddings = np.random.random((200, 3)).tolist()
    ids = [str(item + 1) for item in range(200)]
    collection.add(embeddings=embeddings, ids=ids)

    # 打印数据
    print(collection.peek(2))

if __name__ == '__main__':
    test()
{'ids': ['1', '2'], 'embeddings': [[0.8672704696655273, 0.06121295690536499, 0.40484386682510376], [0.9951606392860413, 0.690493643283844, 0.8979604244232178]], 'metadatas': [None, None], 'documents': ['第1文本内容', '第2文本内容']}
未经允许不得转载:一亩三分地 » Chroma Hello World
评论 (0)

5 + 9 =