176 lines
5.6 KiB
Python
176 lines
5.6 KiB
Python
# 封装统一调用openai的客户端
|
|
from typing import Optional, Iterator
|
|
import os
|
|
|
|
import requests
|
|
import json
|
|
|
|
|
|
class Message:
    """A complete chat message taken from a non-streaming completion choice.

    Attributes:
        content: The message text, or None when the key is missing.
        role: The author role string, or None when the key is missing.
    """

    def __init__(self, data):
        # `data` is the decoded "message" mapping; absent keys map to None.
        lookup = data.get
        self.content = lookup("content")
        self.role = lookup("role")
|
|
|
|
|
|
class Choice:
    """One entry of the ``choices`` array in a non-streaming completion.

    Attributes:
        index: Position of this choice as reported by the server (or None).
        finish_reason: Server-reported reason generation stopped (or None).
        message: The generated ``Message``.
    """

    def __init__(self, choice):
        self.index = choice.get("index")
        self.finish_reason = choice.get("finish_reason")
        # `or {}` also covers a key present with a JSON null value; in that
        # case .get("message", {}) would return None and Message() would crash.
        self.message = Message(choice.get("message") or {})
|
|
|
|
|
|
class ChatCompletionResponse:
    """Parsed body of a non-streaming ``/chat/completions`` response.

    Attributes:
        id, object, created, model: Top-level fields copied as-is (None if absent).
        choices: List of ``Choice`` objects (empty if absent or null).
        usage: Dict with ``prompt_tokens``, ``completion_tokens`` and
            ``total_tokens``; each value is None when not reported.
    """

    _USAGE_KEYS = ("prompt_tokens", "completion_tokens", "total_tokens")

    def __init__(self, data) -> None:
        self.id = data.get("id")
        self.object = data.get("object")
        self.created = data.get("created")
        self.model = data.get("model")

        # `or` guards against keys that are present but JSON null, which
        # .get(key, default) would hand back as None.
        choices_data = data.get("choices") or []
        self.choices = [Choice(choice) for choice in choices_data]

        usage_data = data.get("usage") or {}
        self.usage = {key: usage_data.get(key) for key in self._USAGE_KEYS}
|
|
|
|
|
|
class DeltaMessage:
    """Incremental message fragment carried by one streaming choice.

    Attributes:
        content: Text fragment in this chunk, or None when the key is missing.
        role: Role string when sent in this chunk, otherwise None.
    """

    def __init__(self, data) -> None:
        # `data` is the decoded "delta" mapping; absent keys map to None.
        fetch = data.get
        self.content = fetch("content")
        self.role = fetch("role")
|
|
|
|
|
|
class DeltaChoice:
    """One entry of the ``choices`` array in a streaming chunk.

    Attributes:
        index: Position of this choice as reported by the server (or None).
        finish_reason: None until the server reports why generation stopped.
        delta: The ``DeltaMessage`` fragment for this chunk.
    """

    def __init__(self, data):
        self.index = data.get("index")
        self.finish_reason = data.get("finish_reason")
        # `or {}` also covers "delta": null, which .get("delta", {}) would
        # pass through as None and crash DeltaMessage.
        self.delta = DeltaMessage(data.get("delta") or {})
|
|
|
|
|
|
class StreamChunk:
    """One parsed ``data:`` event of a streaming chat completion.

    Attributes:
        id, object, created, model: Top-level fields copied as-is (None if absent).
        choices: List of ``DeltaChoice`` objects (empty if absent or null).
        usage: Dict with ``prompt_tokens``, ``completion_tokens`` and
            ``total_tokens`` when the chunk carries a non-empty ``usage``
            object, otherwise None.
    """

    _USAGE_KEYS = ("prompt_tokens", "completion_tokens", "total_tokens")

    def __init__(self, data):
        self.id = data.get("id")
        self.object = data.get("object")
        self.created = data.get("created")
        self.model = data.get("model")

        # `or []` also covers "choices": null.
        choices_data = data.get("choices") or []
        self.choices = [DeltaChoice(choice) for choice in choices_data]

        # Usage is normally absent on content chunks; presumably only a final
        # chunk carries it when the request asks for usage reporting
        # (provider-dependent — confirm for your backend).
        usage_data = data.get("usage")
        self.usage = (
            {key: usage_data.get(key) for key in self._USAGE_KEYS}
            if usage_data
            else None
        )
|
|
|
|
|
|
class Stream:
    """Iterable wrapper around a streaming (SSE) chat-completion HTTP response.

    Iterating yields one ``StreamChunk`` per ``data:`` event. Iteration stops
    at the ``[DONE]`` sentinel or at the end of the body. The underlying
    response is closed when iteration ends or the generator is closed, and on
    leaving a ``with`` block.
    """

    def __init__(self, response: requests.Response):
        # The open HTTP response whose body is read lazily while iterating.
        self.response = response

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Release the connection even if the stream was not fully consumed.
        self.response.close()

    def __iter__(self) -> Iterator[StreamChunk]:
        """Yield parsed chunks from the server-sent event stream.

        Example event line:
            data: {"id":"...","object":"chat.completion.chunk","created":1765355109,
                   "model":"deepseek-chat","choices":[{"index":0,
                   "delta":{"content":"..."},"finish_reason":null}]}

        The following lines are skipped:
            - blank lines;
            - lines that are not ``data:`` fields (e.g. ``: comment``,
              ``event:`` or ``id:``);
            - ``data:`` payloads that are not valid JSON.
        """
        try:
            # Iterate raw bytes and decode UTF-8 per line instead of using
            # decode_unicode=True, for two reasons:
            #   * requests falls back to ISO-8859-1 for a text/* Content-Type
            #     without a charset, which garbles non-ASCII content;
            #   * str.splitlines() also breaks on U+2028/U+2029 etc., which
            #     can split a JSON payload in two.
            # bytes.splitlines() only breaks on \r, \n and \r\n, and a UTF-8
            # multi-byte sequence never contains those bytes, so each line
            # decodes cleanly.
            for raw_line in self.response.iter_lines():
                if isinstance(raw_line, bytes):
                    line = raw_line.decode("utf-8", errors="replace")
                else:
                    line = raw_line

                if not line.strip():
                    continue
                if not line.startswith("data:"):
                    continue

                # The SSE spec makes the single space after "data:" optional.
                payload = line[len("data:"):]
                if payload.startswith(" "):
                    payload = payload[1:]

                # "[DONE]" is the end-of-stream sentinel.
                if payload.strip() == "[DONE]":
                    break

                try:
                    data = json.loads(payload)
                except json.JSONDecodeError:
                    # Best-effort: skip a malformed payload rather than abort.
                    continue
                yield StreamChunk(data)
        finally:
            self.response.close()
|
|
|
|
|
|
class ChatCompletions:
    """The ``client.chat.completions`` resource; sends chat-completion requests."""

    def __init__(self, client):
        # The owning client; supplies base_url, api_key and timeout.
        self._client = client

    def create(
        self,
        model,
        messages,
        max_tokens=1024,
        temperature=0.7,
        stream: bool = False,
        **kwargs,
    ):
        """POST a chat-completion request to ``{base_url}/chat/completions``.

        Args:
            model: Sent as the ``"model"`` body field.
            messages: Sent unchanged as the ``"messages"`` body field.
            max_tokens: Sent as ``"max_tokens"``; omitted when None.
            temperature: Sent as ``"temperature"``; omitted when None.
            stream: When true, sends ``"stream": true``, reads the response
                body lazily and returns a ``Stream``.
            **kwargs: Additional request-body fields, merged into the body as-is.

        Returns:
            A ``Stream`` when ``stream`` is true, otherwise a
            ``ChatCompletionResponse`` built from the JSON body.

        Raises:
            requests.HTTPError: The server answered with a 4xx/5xx status. The
                message includes the response body, and ``.response`` is set.
            requests.RequestException: Connection errors or timeouts raised by
                ``requests``.
        """
        url = f"{self._client.base_url}/chat/completions"

        body = {"model": model, "messages": messages}
        if max_tokens is not None:
            body["max_tokens"] = max_tokens
        if temperature is not None:
            body["temperature"] = temperature
        if stream:
            body["stream"] = True
        # Merge any other API parameters supplied by the caller.
        body.update(kwargs)

        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": f"Bearer {self._client.api_key}",
        }

        # requests' own stream=True is client-side: it defers downloading the
        # body so SSE events can be read incrementally. The "stream" body field
        # above is what asks the server to stream.
        response = requests.post(
            url,
            headers=headers,
            json=body,
            timeout=self._client.timeout,
            stream=bool(stream),
        )
        self._raise_for_status(response)

        if stream:
            return Stream(response)
        return ChatCompletionResponse(response.json())

    @staticmethod
    def _raise_for_status(response):
        """Raise ``requests.HTTPError`` for a 4xx/5xx response; otherwise do nothing.

        The response is closed before raising; a streamed response would
        otherwise hold its connection. The body is appended to the message
        because requests' default message omits it.
        """
        try:
            response.raise_for_status()
        except requests.HTTPError as err:
            try:
                detail = response.text
            except requests.RequestException:
                detail = ""
            finally:
                response.close()
            message = f"{err}; response body: {detail}" if detail else str(err)
            raise requests.HTTPError(message, response=response) from err
|
|
|
|
|
|
class ChatResource:
    """The ``client.chat`` namespace; groups chat-related resources."""

    def __init__(self, client):
        # The owning client, handed on to each sub-resource.
        self.client = client

    @property
    def completions(self):
        """Return a new ``ChatCompletions`` bound to this resource's client."""
        owner = self.client
        return ChatCompletions(owner)
|
|
|
|
|
|
class OpenAI:
    """Minimal OpenAI-compatible HTTP client.

    Args:
        base_url: API root; any trailing "/" is stripped. Defaults to the
            DeepSeek v1 endpoint.
        api_key: Bearer token. When falsy, the ``OPENAI_API_KEY`` environment
            variable is used instead.
        timeout: Passed as the ``timeout`` of every HTTP request.

    Raises:
        ValueError: Neither ``api_key`` nor ``OPENAI_API_KEY`` provides a key.
    """

    def __init__(
        self,
        base_url: str = "https://api.deepseek.com/v1",
        api_key: Optional[str] = None,
        timeout: float = 60.0,
    ):
        self.api_key = api_key if api_key else os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("API秘钥未设置,请设置api_key参数或设置环境变量OPENAI_API_KEY")
        self.timeout = timeout
        self.base_url = base_url.rstrip("/")

    @property
    def chat(self):
        """Attribute-style access to the chat namespace (``client.chat``)."""
        return ChatResource(self)
|