跳到主内容

Argilla

Argilla 是一个协作式标注工具,适用于需要为其项目构建高质量数据集的 AI 工程师和领域专家。

入门指南

要将数据记录到 Argilla,首先需要部署 Argilla 服务器。如果您尚未部署 Argilla 服务器,请按照此处的说明操作。

接下来,您需要配置并创建 Argilla 数据集。

import argilla as rg

client = rg.Argilla(api_url="<api_url>", api_key="<api_key>")

settings = rg.Settings(
guidelines="These are some guidelines.",
fields=[
rg.ChatField(
name="user_input",
),
rg.TextField(
name="llm_output",
),
],
questions=[
rg.RatingQuestion(
name="rating",
values=[1, 2, 3, 4, 5, 6, 7],
),
],
)

dataset = rg.Dataset(
name="my_first_dataset",
settings=settings,
)

dataset.create()

有关进一步的配置,请参阅Argilla 文档

用法

import os
import litellm
from litellm import completion

# add env vars
os.environ["ARGILLA_API_KEY"]="argilla.apikey"
os.environ["ARGILLA_BASE_URL"]="http://localhost:6900"
os.environ["ARGILLA_DATASET_NAME"]="my_first_dataset"
os.environ["OPENAI_API_KEY"]="sk-proj-..."

litellm.callbacks = ["argilla"]

# add argilla transformation object
litellm.argilla_transformation_object = {
"user_input": "messages", # 👈 key= argilla field, value = either message (argilla.ChatField) | response (argilla.TextField)
"llm_output": "response"
}

## LLM CALL ##
response = completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, how are you?"}],
)

示例输出

向 Argilla 调用添加采样率

要仅记录部分 Argilla 调用,请将 ARGILLA_SAMPLING_RATE 添加到您的环境变量中。

ARGILLA_SAMPLING_RATE=0.1 # log 10% of calls to argilla