diff --git a/.env.example b/.env.example index ff7488f..592a105 100644 --- a/.env.example +++ b/.env.example @@ -4,3 +4,4 @@ LLM_KEY_BAILIAN=sk-248f811295914adcad837XXXXXXXXXXX # 阿里云百炼模型 LLM_KEY_DEEPSEEK=sk-a89d0ff9421a43fca5f0xxxxxxxxxxxx # Deepseek模型服务平台key SERVER_PORT = 7004 # 服务启动端口 +SERVER_ENABLE_CORS = False # 是否允许跨域 diff --git a/app/config/env.py b/app/config/env.py index b4b0a75..136d2ea 100644 --- a/app/config/env.py +++ b/app/config/env.py @@ -9,6 +9,7 @@ class EnvSettings(BaseSettings): llm_key_deepseek: str = Field(..., env="LLM_KEY_DEEPSEEK") server_port: int = Field(..., env="SERVER_PORT") + server_enable_cors: bool = Field(..., env="SERVER_ENABLE_CORS") class Config: env_file = ".env" diff --git a/app/main.py b/app/main.py index d21665d..622e4a0 100644 --- a/app/main.py +++ b/app/main.py @@ -6,6 +6,7 @@ from fastapi.openapi.docs import get_swagger_ui_oauth2_redirect_html, get_redoc_ from fastapi.responses import RedirectResponse from langchain.chat_models import init_chat_model from langserve import add_routes +from starlette.middleware.cors import CORSMiddleware from starlette.staticfiles import StaticFiles from app.config.env import env @@ -32,4 +33,12 @@ app = FastAPI( for add_route_func in routes: add_route_func(app) -print("/*---------------------------------------main-------------------------------------------*/") +if env.server_enable_cors: + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + expose_headers=["*"], + )