Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions internal/controller/ai_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ func (c *AIController) createOpenAIClient() *openai.Client {

aiProvider := aiConfig.GetProvider()

if aiProvider.Provider == "azure_ai" {
return c.createAzureAIClient(aiProvider)
}

config = openai.DefaultConfig(aiProvider.APIKey)
config.BaseURL = aiProvider.APIHost
if !strings.HasSuffix(config.BaseURL, "/v1") {
Expand All @@ -300,6 +304,36 @@ func (c *AIController) createOpenAIClient() *openai.Client {
return openai.NewClientWithConfig(config)
}

// createAzureAIClient creates an OpenAI client configured for Azure AI.
// Uses the Azure OpenAI compatibility endpoint: https://{resource}.openai.azure.com/openai/v1
func (c *AIController) createAzureAIClient(aiProvider *schema.SiteAIProvider) *openai.Client {
azureBaseURL := strings.TrimRight(aiProvider.APIHost, "/") + "/openai/v1"

config := openai.DefaultConfig(aiProvider.APIKey)
config.BaseURL = azureBaseURL
config.HTTPClient = &http.Client{
Transport: &azureAPIKeyTransport{
apiKey: aiProvider.APIKey,
transport: http.DefaultTransport,
},
}
return openai.NewClientWithConfig(config)
}

// azureAPIKeyTransport is an http.RoundTripper that replaces the Authorization
// header with the Azure-style api-key header for Azure OpenAI requests.
type azureAPIKeyTransport struct {
apiKey string
transport http.RoundTripper
}

func (t *azureAPIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
req.Header.Del("Authorization")
req.Header.Set("Api-Key", t.apiKey)
return t.transport.RoundTrip(req)
}

// getPromptByLanguage
func (c *AIController) getPromptByLanguage(language i18n.Language, question string) string {
aiConfig, err := c.siteInfoService.GetSiteAI(context.Background())
Expand Down Expand Up @@ -497,6 +531,10 @@ func (c *AIController) processAIStream(
break
}

if len(response.Choices) == 0 {
continue
}

choice := response.Choices[0]

if len(choice.Delta.ToolCalls) > 0 {
Expand Down
2 changes: 1 addition & 1 deletion internal/migrations/init_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ var (
{ID: 128, Key: "rank.answer.undeleted", Value: `-1`},
{ID: 129, Key: "rank.question.undeleted", Value: `-1`},
{ID: 130, Key: "rank.tag.undeleted", Value: `-1`},
{ID: 131, Key: "ai_config.provider", Value: `[{"default_api_host":"https://api.openai.com","display_name":"OpenAI","name":"openai"},{"default_api_host":"https://generativelanguage.googleapis.com","display_name":"Gemini","name":"gemini"},{"default_api_host":"https://api.anthropic.com","display_name":"Anthropic","name":"anthropic"}]`},
{ID: 131, Key: "ai_config.provider", Value: `[{"default_api_host":"https://api.openai.com","display_name":"OpenAI","name":"openai"},{"default_api_host":"https://generativelanguage.googleapis.com","display_name":"Gemini","name":"gemini"},{"default_api_host":"https://api.anthropic.com","display_name":"Anthropic","name":"anthropic"},{"default_api_host":"https://{your-resource}.openai.azure.com","display_name":"Azure AI","name":"azure_ai"}]`},
}

defaultBadgeGroupTable = []*entity.BadgeGroup{
Expand Down
2 changes: 1 addition & 1 deletion internal/migrations/v31.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func addAPIKey(ctx context.Context, x *xorm.Engine) error {
}

defaultConfigTable := []*entity.Config{
{ID: 131, Key: "ai_config.provider", Value: `[{"default_api_host":"https://api.openai.com","display_name":"OpenAI","name":"openai"},{"default_api_host":"https://generativelanguage.googleapis.com","display_name":"Gemini","name":"gemini"},{"default_api_host":"https://api.anthropic.com","display_name":"Anthropic","name":"anthropic"}]`},
{ID: 131, Key: "ai_config.provider", Value: `[{"default_api_host":"https://api.openai.com","display_name":"OpenAI","name":"openai"},{"default_api_host":"https://generativelanguage.googleapis.com","display_name":"Gemini","name":"gemini"},{"default_api_host":"https://api.anthropic.com","display_name":"Anthropic","name":"anthropic"},{"default_api_host":"https://{your-resource}.openai.azure.com","display_name":"Azure AI","name":"azure_ai"}]`},
}
for _, c := range defaultConfigTable {
exist, err := x.Context(ctx).Get(&entity.Config{Key: c.Key})
Expand Down
18 changes: 16 additions & 2 deletions internal/schema/ai_config_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ type GetAIModelsResp struct {
}

type GetAIModelsReq struct {
APIHost string `json:"api_host"`
APIKey string `json:"api_key"`
Provider string `json:"provider"`
APIHost string `json:"api_host"`
APIKey string `json:"api_key"`
}

// GetAIModelResp get AI model response
Expand All @@ -49,3 +50,16 @@ type GetAIModelResp struct {
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
}

// GetAzureDeploymentsResp Azure OpenAI deployments response
type GetAzureDeploymentsResp struct {
Data []struct {
Id string `json:"id"`
Model string `json:"model"`
Owner string `json:"owner"`
Object string `json:"object"`
Status string `json:"status"`
CreatedAt int `json:"created_at"`
UpdatedAt int `json:"updated_at"`
} `json:"data"`
}
29 changes: 27 additions & 2 deletions internal/service/siteinfo/siteinfo_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -724,9 +724,20 @@ func (s *SiteInfoService) GetAIModels(ctx context.Context, req *schema.GetAIMode
}

r := resty.New()
r.SetHeader("Authorization", fmt.Sprintf("Bearer %s", req.APIKey))
r.SetHeader("Content-Type", "application/json")
respBody, err := r.R().Get(req.APIHost + "/v1/models")

var respBody *resty.Response
apiHost := strings.TrimRight(req.APIHost, "/")
if req.Provider == "azure_ai" {
// Azure AI: list deployments via the Azure OpenAI endpoint
r.SetHeader("api-key", req.APIKey)
deploymentsURL := apiHost + "/openai/deployments?api-version=2022-12-01"
respBody, err = r.R().Get(deploymentsURL)
} else {
// Standard OpenAI-compatible providers
r.SetHeader("Authorization", fmt.Sprintf("Bearer %s", req.APIKey))
respBody, err = r.R().Get(apiHost + "/v1/models")
}
if err != nil {
log.Error(err)
return resp, errors.BadRequest(fmt.Sprintf("failed to get AI models %s", err.Error()))
Expand All @@ -736,6 +747,20 @@ func (s *SiteInfoService) GetAIModels(ctx context.Context, req *schema.GetAIMode
return resp, errors.BadRequest(fmt.Sprintf("failed to get AI models, response: %s", respBody.String()))
}

if req.Provider == "azure_ai" {
data := schema.GetAzureDeploymentsResp{}
_ = json.Unmarshal(respBody.Body(), &data)
for _, d := range data.Data {
resp = append(resp, &schema.GetAIModelResp{
Id: d.Id,
Object: d.Object,
Created: d.CreatedAt,
OwnedBy: d.Model,
})
}
return resp, nil
}

data := schema.GetAIModelsResp{}
_ = json.Unmarshal(respBody.Body(), &data)

Expand Down
3 changes: 3 additions & 0 deletions ui/src/pages/Admin/AiSettings/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ const Index = () => {

const checkAiConfigData = (data) => {
const params = data || {
provider: formData.provider.value,
api_host: formData.api_host.value || apiHostPlaceholder,
api_key: formData.api_key.value,
};
Expand Down Expand Up @@ -151,6 +152,7 @@ const Index = () => {
const host = findHistoryProvider?.api_host || provider?.default_api_host;
if (findHistoryProvider?.model) {
checkAiConfigData({
provider: value,
api_host: host,
api_key: findHistoryProvider.api_key,
});
Expand Down Expand Up @@ -263,6 +265,7 @@ const Index = () => {
);
const host = currentAiConfig.api_host || provider?.default_api_host;
checkAiConfigData({
provider: currentAiConfig.provider,
api_host: host,
api_key: currentAiConfig.api_key,
});
Expand Down
Loading