diff --git a/internal/controller/ai_controller.go b/internal/controller/ai_controller.go index e020ed30e..41b0afa9a 100644 --- a/internal/controller/ai_controller.go +++ b/internal/controller/ai_controller.go @@ -292,6 +292,10 @@ func (c *AIController) createOpenAIClient() *openai.Client { aiProvider := aiConfig.GetProvider() + if aiProvider.Provider == "azure_ai" { + return c.createAzureAIClient(aiProvider) + } + config = openai.DefaultConfig(aiProvider.APIKey) config.BaseURL = aiProvider.APIHost if !strings.HasSuffix(config.BaseURL, "/v1") { @@ -300,6 +304,36 @@ func (c *AIController) createOpenAIClient() *openai.Client { return openai.NewClientWithConfig(config) } +// createAzureAIClient creates an OpenAI client configured for Azure AI. +// Uses the Azure OpenAI compatibility endpoint: https://{resource}.openai.azure.com/openai/v1 +func (c *AIController) createAzureAIClient(aiProvider *schema.SiteAIProvider) *openai.Client { + azureBaseURL := strings.TrimRight(aiProvider.APIHost, "/") + "/openai/v1" + + config := openai.DefaultConfig(aiProvider.APIKey) + config.BaseURL = azureBaseURL + config.HTTPClient = &http.Client{ + Transport: &azureAPIKeyTransport{ + apiKey: aiProvider.APIKey, + transport: http.DefaultTransport, + }, + } + return openai.NewClientWithConfig(config) +} + +// azureAPIKeyTransport is an http.RoundTripper that replaces the Authorization +// header with the Azure-style api-key header for Azure OpenAI requests. +type azureAPIKeyTransport struct { + apiKey string + transport http.RoundTripper +} + +func (t *azureAPIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Del("Authorization") + req.Header.Set("Api-Key", t.apiKey) + return t.transport.RoundTrip(req) +} + // getPromptByLanguage func (c *AIController) getPromptByLanguage(language i18n.Language, question string) string { aiConfig, err := c.siteInfoService.GetSiteAI(context.Background()) @@ -497,6 +531,10 @@ func (c *AIController) processAIStream( break } + if len(response.Choices) == 0 { + continue + } + choice := response.Choices[0] if len(choice.Delta.ToolCalls) > 0 { diff --git a/internal/migrations/init_data.go b/internal/migrations/init_data.go index 5af41bbfc..1fd9c5222 100644 --- a/internal/migrations/init_data.go +++ b/internal/migrations/init_data.go @@ -353,7 +353,7 @@ var ( {ID: 128, Key: "rank.answer.undeleted", Value: `-1`}, {ID: 129, Key: "rank.question.undeleted", Value: `-1`}, {ID: 130, Key: "rank.tag.undeleted", Value: `-1`}, - {ID: 131, Key: "ai_config.provider", Value: `[{"default_api_host":"https://api.openai.com","display_name":"OpenAI","name":"openai"},{"default_api_host":"https://generativelanguage.googleapis.com","display_name":"Gemini","name":"gemini"},{"default_api_host":"https://api.anthropic.com","display_name":"Anthropic","name":"anthropic"}]`}, + {ID: 131, Key: "ai_config.provider", Value: `[{"default_api_host":"https://api.openai.com","display_name":"OpenAI","name":"openai"},{"default_api_host":"https://generativelanguage.googleapis.com","display_name":"Gemini","name":"gemini"},{"default_api_host":"https://api.anthropic.com","display_name":"Anthropic","name":"anthropic"},{"default_api_host":"https://{your-resource}.openai.azure.com","display_name":"Azure AI","name":"azure_ai"}]`}, } defaultBadgeGroupTable = []*entity.BadgeGroup{ diff --git a/internal/migrations/v31.go b/internal/migrations/v31.go index 428c5828a..d2be0c9f5 100644 --- a/internal/migrations/v31.go +++ b/internal/migrations/v31.go @@ -61,7 +61,7 @@ func addAPIKey(ctx context.Context, x *xorm.Engine) error { } defaultConfigTable := []*entity.Config{ - {ID: 131, Key: "ai_config.provider", Value: `[{"default_api_host":"https://api.openai.com","display_name":"OpenAI","name":"openai"},{"default_api_host":"https://generativelanguage.googleapis.com","display_name":"Gemini","name":"gemini"},{"default_api_host":"https://api.anthropic.com","display_name":"Anthropic","name":"anthropic"}]`}, + {ID: 131, Key: "ai_config.provider", Value: `[{"default_api_host":"https://api.openai.com","display_name":"OpenAI","name":"openai"},{"default_api_host":"https://generativelanguage.googleapis.com","display_name":"Gemini","name":"gemini"},{"default_api_host":"https://api.anthropic.com","display_name":"Anthropic","name":"anthropic"},{"default_api_host":"https://{your-resource}.openai.azure.com","display_name":"Azure AI","name":"azure_ai"}]`}, } for _, c := range defaultConfigTable { exist, err := x.Context(ctx).Get(&entity.Config{Key: c.Key}) diff --git a/internal/schema/ai_config_schema.go b/internal/schema/ai_config_schema.go index 6ac686343..8aba6d7da 100644 --- a/internal/schema/ai_config_schema.go +++ b/internal/schema/ai_config_schema.go @@ -38,8 +38,9 @@ type GetAIModelsResp struct { } type GetAIModelsReq struct { - APIHost string `json:"api_host"` - APIKey string `json:"api_key"` + Provider string `json:"provider"` + APIHost string `json:"api_host"` + APIKey string `json:"api_key"` } // GetAIModelResp get AI model response @@ -49,3 +50,16 @@ type GetAIModelResp struct { Created int `json:"created"` OwnedBy string `json:"owned_by"` } + +// GetAzureDeploymentsResp Azure OpenAI deployments response +type GetAzureDeploymentsResp struct { + Data []struct { + Id string `json:"id"` + Model string `json:"model"` + Owner string `json:"owner"` + Object string `json:"object"` + Status string `json:"status"` + CreatedAt int `json:"created_at"` + UpdatedAt int `json:"updated_at"` + } `json:"data"` +} diff --git a/internal/service/siteinfo/siteinfo_service.go b/internal/service/siteinfo/siteinfo_service.go index 1e25cbaa4..e80e50c5e 100644 --- a/internal/service/siteinfo/siteinfo_service.go +++ b/internal/service/siteinfo/siteinfo_service.go @@ -724,9 +724,20 @@ func (s *SiteInfoService) GetAIModels(ctx context.Context, req *schema.GetAIMode } r := resty.New() - r.SetHeader("Authorization", fmt.Sprintf("Bearer %s", req.APIKey)) r.SetHeader("Content-Type", "application/json") - respBody, err := r.R().Get(req.APIHost + "/v1/models") + + var respBody *resty.Response + apiHost := strings.TrimRight(req.APIHost, "/") + if req.Provider == "azure_ai" { + // Azure AI: list deployments via the Azure OpenAI endpoint + r.SetHeader("api-key", req.APIKey) + deploymentsURL := apiHost + "/openai/deployments?api-version=2022-12-01" + respBody, err = r.R().Get(deploymentsURL) + } else { + // Standard OpenAI-compatible providers + r.SetHeader("Authorization", fmt.Sprintf("Bearer %s", req.APIKey)) + respBody, err = r.R().Get(apiHost + "/v1/models") + } if err != nil { log.Error(err) return resp, errors.BadRequest(fmt.Sprintf("failed to get AI models %s", err.Error())) @@ -736,6 +747,20 @@ func (s *SiteInfoService) GetAIModels(ctx context.Context, req *schema.GetAIMode return resp, errors.BadRequest(fmt.Sprintf("failed to get AI models, response: %s", respBody.String())) } + if req.Provider == "azure_ai" { + data := schema.GetAzureDeploymentsResp{} + _ = json.Unmarshal(respBody.Body(), &data) + for _, d := range data.Data { + resp = append(resp, &schema.GetAIModelResp{ + Id: d.Id, + Object: d.Object, + Created: d.CreatedAt, + OwnedBy: d.Model, + }) + } + return resp, nil + } + data := schema.GetAIModelsResp{} _ = json.Unmarshal(respBody.Body(), &data) diff --git a/ui/src/pages/Admin/AiSettings/index.tsx b/ui/src/pages/Admin/AiSettings/index.tsx index 2270aa5c5..55c3c28e2 100644 --- a/ui/src/pages/Admin/AiSettings/index.tsx +++ b/ui/src/pages/Admin/AiSettings/index.tsx @@ -84,6 +84,7 @@ const Index = () => { const checkAiConfigData = (data) => { const params = data || { + provider: formData.provider.value, api_host: formData.api_host.value || apiHostPlaceholder, api_key: formData.api_key.value, }; @@ -151,6 +152,7 @@ const Index = () => { const host = findHistoryProvider?.api_host || provider?.default_api_host; if (findHistoryProvider?.model) { checkAiConfigData({ + provider: value, api_host: host, api_key: findHistoryProvider.api_key, }); @@ -263,6 +265,7 @@ const Index = () => { ); const host = currentAiConfig.api_host || provider?.default_api_host; checkAiConfigData({ + provider: currentAiConfig.provider, api_host: host, api_key: currentAiConfig.api_key, });