
Adding a Rerank Provider

LiteLLM follows the Cohere Rerank API format for all rerank providers. The steps below describe how to add a new rerank provider.
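For orientation, a Cohere-style rerank exchange looks roughly like the sketch below. The field names (`query`, `documents`, `top_n`, `results`, `index`, `relevance_score`) follow the Cohere Rerank format; the id, scores, and the `ranked_documents` helper are made up for illustration and are not part of LiteLLM.

```python
# Illustrative Cohere-style rerank payloads; the id and scores are made up.
request_body = {
    "model": "rerank-english-v3.0",
    "query": "hello",
    "documents": ["hello", "world"],
    "top_n": 2,
}

response_body = {
    "id": "example-id",
    "results": [
        {"index": 0, "relevance_score": 0.98},
        {"index": 1, "relevance_score": 0.12},
    ],
}


def ranked_documents(request: dict, response: dict) -> list:
    """Hypothetical helper: map each result's index back to its document."""
    return [request["documents"][r["index"]] for r in response["results"]]


print(ranked_documents(request_body, response_body))
```

Each result refers to a document by its position in the request's `documents` list, so a provider integration has to preserve or reconstruct that index.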

1. Create a transformation.py file

Create a config class named <Provider><Endpoint>Config that inherits from BaseRerankConfig.

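import httpx

# Assumed import path for the base class; verify it against your LiteLLM version.
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest, RerankResponse


class YourProviderRerankConfig(BaseRerankConfig):
    def get_supported_cohere_rerank_params(self, model: str) -> list:
        # Cohere-format params that this provider accepts
        return [
            "query",
            "documents",
            "top_n",
            # ... other supported params
        ]

    def transform_rerank_request(self, model: str, optional_rerank_params: OptionalRerankParams, headers: dict) -> dict:
        # Transform request to RerankRequest spec.
        # Assumes the optional params map directly onto RerankRequest fields.
        rerank_request = RerankRequest(model=model, **optional_rerank_params)
        return rerank_request.model_dump(exclude_none=True)

    # Remaining parameters are elided; match the BaseRerankConfig signature.
    def transform_rerank_response(self, model: str, raw_response: httpx.Response, ...) -> RerankResponse:
        # Transform provider response to RerankResponse
        raw_response_json = raw_response.json()
        return RerankResponse(**raw_response_json)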
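When a provider's wire format differs from Cohere's, the two transform methods are where the field mapping happens. The following is a minimal, standalone sketch of that mapping using plain dicts. The provider field names (`q`, `texts`, `limit`, `hits`, `idx`, `score`, `request_id`) and both helper functions are hypothetical and stand in for whatever your provider actually uses.

```python
from typing import Any, Dict


def to_provider_request(model: str, params: Dict[str, Any]) -> Dict[str, Any]:
    """Map Cohere-style params onto a hypothetical provider request body."""
    body: Dict[str, Any] = {
        "model": model,
        "q": params["query"],
        "texts": params["documents"],
    }
    # Only forward top_n when the caller set it.
    if params.get("top_n") is not None:
        body["limit"] = params["top_n"]
    return body


def to_cohere_response(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Map a hypothetical provider response onto the Cohere-style shape."""
    results = [
        {"index": hit["idx"], "relevance_score": hit["score"]}
        for hit in raw["hits"]
    ]
    # Return results ordered by relevance, highest first.
    results.sort(key=lambda r: r["relevance_score"], reverse=True)
    return {"id": raw.get("request_id"), "results": results}


raw = {"request_id": "abc", "hits": [{"idx": 1, "score": 0.2}, {"idx": 0, "score": 0.9}]}
print(to_provider_request("my-model", {"query": "hello", "documents": ["hello", "world"]}))
print(to_cohere_response(raw))
```

In a real config class the same logic would live inside `transform_rerank_request` and `transform_rerank_response`, returning the LiteLLM types rather than bare dicts.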

2. Register your provider

Add your provider to litellm.utils.get_provider_rerank_config().

elif litellm.LlmProviders.YOUR_PROVIDER == provider:
    return litellm.YourProviderRerankConfig()
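The registration step is a plain provider-to-config dispatch. The standalone sketch below mirrors that pattern with stand-in names; `Provider`, `CohereConfig`, `YourProviderConfig`, and `get_rerank_config` are all hypothetical and only illustrate the shape of the branch you add, not LiteLLM's actual implementation.

```python
from enum import Enum
from typing import Optional


class Provider(str, Enum):
    """Stand-in for a provider enum; the values are illustrative."""
    COHERE = "cohere"
    YOUR_PROVIDER = "your_provider"
    OTHER = "other"


class CohereConfig:
    """Stand-in for an existing provider's rerank config."""


class YourProviderConfig:
    """Stand-in for the new provider's rerank config."""


def get_rerank_config(provider: Provider) -> Optional[object]:
    """Return the rerank config for a provider, or None if unsupported."""
    if provider == Provider.COHERE:
        return CohereConfig()
    elif provider == Provider.YOUR_PROVIDER:
        # The new branch: one comparison, one config instance.
        return YourProviderConfig()
    return None


print(type(get_rerank_config(Provider("your_provider"))).__name__)
```

Returning `None` for unknown providers lets the caller decide whether to raise or fall back; whether LiteLLM does exactly this is not stated here.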

3. Add your provider to rerank_api/main.py

Add a code block that handles calls to your provider. Your provider should use the base_llm_http_handler.rerank method.

elif _custom_llm_provider == "your_provider":
    ...
    response = base_llm_http_handler.rerank(
        model=model,
        custom_llm_provider=_custom_llm_provider,
        optional_rerank_params=optional_rerank_params,
        logging_obj=litellm_logging_obj,
        timeout=optional_params.timeout,
        api_key=dynamic_api_key or optional_params.api_key,
        api_base=api_base,
        _is_async=_is_async,
        headers=headers or litellm.headers or {},
        client=client,
        model_response=model_response,
    )
    ...

4. Add tests

Add a test file in tests/llm_translation.

import litellm


def test_basic_rerank_cohere():
    response = litellm.rerank(
        model="cohere/rerank-english-v3.0",
        query="hello",
        documents=["hello", "world"],
        top_n=3,
    )

    print("re rank response: ", response)

    # A successful call returns an id and a list of ranked results
    assert response.id is not None
    assert response.results is not None
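Beyond checking that fields are present, a test can also verify the structure of the results. The helper below is a hypothetical sketch, not part of LiteLLM: it assumes results are dicts with `index` and `relevance_score` keys, and that the provider returns them in descending relevance order, as Cohere does.

```python
from typing import List, Optional


def check_rerank_results(results: List[dict], num_documents: int, top_n: Optional[int] = None) -> None:
    """Raise ValueError if Cohere-style rerank results look malformed."""
    indexes = [r["index"] for r in results]
    # Every index must point at one of the submitted documents.
    if any(i < 0 or i >= num_documents for i in indexes):
        raise ValueError("result index out of range")
    # A document should not be ranked twice.
    if len(set(indexes)) != len(indexes):
        raise ValueError("duplicate result index")
    # top_n is an upper bound on the number of results.
    if top_n is not None and len(results) > top_n:
        raise ValueError("more results than top_n")
    # Assumed ordering: highest relevance first.
    scores = [r["relevance_score"] for r in results]
    if scores != sorted(scores, reverse=True):
        raise ValueError("results not sorted by relevance_score")


check_rerank_results(
    [{"index": 1, "relevance_score": 0.9}, {"index": 0, "relevance_score": 0.2}],
    num_documents=2,
    top_n=3,
)
```

Because the helper takes plain dicts, it can be exercised against canned provider responses without making a network call.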

Reference PRs