137 lines
4.8 KiB
Python
137 lines
4.8 KiB
Python
|
|
from typing import Optional,Iterator
|
|
import os
|
|
import requests
|
|
import json
|
|
class Message:
    """A single chat message taken from one choice of a non-streaming response.

    Attributes:
        role: The ``role`` value of the message dict (``None`` if absent).
        content: The ``content`` value of the message dict (``None`` if absent).
    """

    def __init__(self, data):
        # ``data`` is the ``message`` dict of one choice in the response JSON.
        # No trailing comma after the call: that would store ``role`` as a
        # one-element tuple such as ``('assistant',)`` instead of a string.
        self.role = data.get('role')
        self.content = data.get('content')

class Choice:
    """One entry of the ``choices`` list in a non-streaming completion response.

    Attributes:
        index: The choice's ``index`` value.
        message: A ``Message`` built from the choice's ``message`` dict.
        finish_reason: The choice's ``finish_reason`` value.
    """

    def __init__(self, data):
        self.index = data.get('index')
        # ``or {}`` also covers ``"message": null``; ``data.get('message', {})``
        # would pass ``None`` through and ``Message`` would fail on ``None.get``.
        self.message = Message(data.get('message') or {})
        self.finish_reason = data.get('finish_reason')

class ChatCompletionResponse:
    """Parsed body of a non-streaming ``/chat/completions`` response.

    Attributes:
        id, object, created, model: Copied from the top-level response fields
            (``None`` when missing).
        choices: List of ``Choice`` objects, one per entry of ``choices``.
        usage: Dict with ``prompt_tokens``, ``completion_tokens`` and
            ``total_tokens``; each value is ``None`` when not reported.
    """

    def __init__(self, data):
        self.id = data.get('id')
        self.object = data.get('object')
        self.created = data.get('created')
        self.model = data.get('model')
        # ``dict.get(key, default)`` only substitutes the default for a missing
        # key; ``or`` additionally handles keys that are present but JSON null.
        choices_data = data.get('choices') or []
        self.choices = [Choice(choice_data) for choice_data in choices_data]
        usage_data = data.get('usage') or {}
        self.usage = {
            "prompt_tokens": usage_data.get("prompt_tokens"),
            "completion_tokens": usage_data.get("completion_tokens"),
            "total_tokens": usage_data.get("total_tokens"),
        }

class DeltaMessage:
    """Incremental message fragment carried by one streaming choice.

    ``content`` and ``role`` mirror the same-named keys of the ``delta`` dict;
    either attribute is ``None`` when the chunk does not include that key.
    """

    def __init__(self, data):
        lookup = data.get
        self.content, self.role = lookup('content'), lookup('role')

class DeltaChoice:
    """One entry of the ``choices`` list in a streaming chunk.

    Attributes:
        index: The choice's ``index`` value.
        delta: A ``DeltaMessage`` built from the choice's ``delta`` dict.
        finish_reason: The choice's ``finish_reason`` value.
    """

    def __init__(self, data):
        self.index = data.get('index')
        # ``or {}`` also covers ``"delta": null``; ``data.get('delta', {})``
        # would hand ``None`` to ``DeltaMessage``, which calls ``.get`` on it.
        self.delta = DeltaMessage(data.get('delta') or {})
        self.finish_reason = data.get('finish_reason')

class StreamChunk:
    """One parsed data chunk of a streaming (SSE) chat completion.

    Attributes:
        id, object, created, model: Copied from the chunk's top-level fields
            (``None`` when missing).
        choices: List of ``DeltaChoice`` objects, one per entry of ``choices``.
    """

    def __init__(self, data):
        self.id = data.get('id')
        self.object = data.get('object')
        self.created = data.get('created')
        self.model = data.get('model')
        # ``or []`` tolerates ``"choices": null`` as well as a missing key.
        choices_data = data.get('choices') or []
        self.choices = [DeltaChoice(choice_data) for choice_data in choices_data]

class Stream:
    """Iterable, closeable wrapper around a streaming ``/chat/completions`` response.

    Iterating yields one ``StreamChunk`` per JSON ``data:`` line of the
    Server-Sent Events body, stopping at the ``[DONE]`` sentinel or at the end
    of the body. The response is closed when iteration ends (normally or by an
    exception) and when a ``with`` block using this object exits.
    """

    def __init__(self, response: "requests.Response"):
        # String annotation: the class can be defined without resolving ``requests``.
        self.response = response

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.response.close()

    def __iter__(self) -> "Iterator[StreamChunk]":
        """Yield a ``StreamChunk`` for each JSON ``data:`` event in the stream."""
        try:
            # Read raw bytes and decode them as UTF-8 here. SSE streams are
            # always UTF-8, but ``decode_unicode=True`` makes requests fall back
            # to ISO-8859-1 for ``text/*`` responses without a charset, which
            # garbles non-ASCII text. ``str.splitlines`` would also split on
            # characters such as U+2028 that may appear unescaped inside JSON;
            # splitting bytes only breaks on CR/LF.
            for raw_line in self.response.iter_lines():
                if isinstance(raw_line, bytes):
                    line = raw_line.decode('utf-8', errors='replace')
                else:
                    line = raw_line
                if not line.strip():
                    continue
                # Comments (": ...") and other fields ("event:", "id:") are ignored.
                if not line.startswith('data:'):
                    continue
                # SSE allows the single space after the colon to be omitted.
                payload = line[len('data:'):]
                if payload.startswith(' '):
                    payload = payload[1:]
                # "[DONE]" marks the end of the stream.
                if payload.strip() == "[DONE]":
                    break
                try:
                    data = json.loads(payload)
                except json.JSONDecodeError:
                    # Best effort: skip a malformed event instead of aborting the stream.
                    continue
                yield StreamChunk(data)
        finally:
            self.response.close()

class ChatCompletions:
    """``client.chat.completions`` resource: sends ``POST {base_url}/chat/completions``."""

    def __init__(self, client):
        # ``client`` must expose ``base_url``, ``api_key`` and ``timeout``,
        # which are the attributes read by ``create`` below.
        self._client = client

    def create(self, model, messages, max_tokens=1024, temperature=0.7, stream: bool = False, **kwargs):
        """Create a chat completion.

        Args:
            model: Model name, sent as ``model``.
            messages: Message list, sent as ``messages``.
            max_tokens: Sent as ``max_tokens`` unless ``None``.
            temperature: Sent as ``temperature`` unless ``None``.
            stream: When true, request a streamed response and return a ``Stream``.
            **kwargs: Extra request-body fields, merged in last (so they can
                override any of the fields above).

        Returns:
            A ``Stream`` when ``stream`` is true, otherwise a ``ChatCompletionResponse``.

        Raises:
            requests.HTTPError: If the server answers with a non-2xx status.
            requests.RequestException: On connection errors, timeouts, etc.
        """
        url = f"{self._client.base_url}/chat/completions"
        body = {"model": model, "messages": messages}
        if max_tokens is not None:
            body["max_tokens"] = max_tokens
        if temperature is not None:
            body["temperature"] = temperature
        if stream:
            # This body field is what asks the server for an SSE response.
            body["stream"] = True
        # Add any extra parameters to the request body as-is.
        body.update(kwargs)

        headers = {
            "Authorization": f"Bearer {self._client.api_key}",
            "Content-Type": "application/json",
        }

        # ``stream=`` here is a requests option, not a server flag. When true,
        # the body is not downloaded up front, so chunks can be read as they arrive.
        response = requests.post(
            url,
            headers=headers,
            json=body,
            timeout=self._client.timeout,
            stream=bool(stream),
        )

        if not stream:
            # Raise for non-2xx statuses before parsing the JSON body.
            response.raise_for_status()
            return ChatCompletionResponse(response.json())

        if not response.ok:
            try:
                # Buffer the (small) error body so ``err.response.text`` stays readable.
                _ = response.content
            finally:
                # With stream=True, requests keeps the connection checked out
                # until the body is consumed or the response is closed.
                # Release it before raising.
                response.close()
            response.raise_for_status()
        return Stream(response)

class ChatResource:
    """``client.chat`` namespace; exposes the ``completions`` resource."""

    def __init__(self, client):
        self._client = client

    @property
    def completions(self) -> ChatCompletions:
        """Return a new ``ChatCompletions`` bound to this namespace's client."""
        owner = self._client
        return ChatCompletions(owner)

class OpenAI:
    """Minimal synchronous client for an OpenAI-style HTTP API.

    Args:
        api_key: Bearer token. Falls back to the ``OPENAI_API_KEY`` environment
            variable when not given or empty.
        base_url: API root; trailing slashes are removed.
        timeout: Stored as ``self.timeout`` for use as the request timeout.

    Raises:
        ValueError: If no API key is available from either source.
    """

    def __init__(self, api_key: Optional[str] = None, base_url: str = "https://api.openai.com/v1", timeout: float = 60.0):
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("API密钥未设置,请设置api_key参数或者环境变量OPENAI_API_KEY")
        self.base_url, self.timeout = base_url.rstrip('/'), timeout

    @property
    def chat(self) -> ChatResource:
        """Entry point for chat endpoints, e.g. ``client.chat.completions.create(...)``."""
        return ChatResource(self)
