From 78a54f001358b53ab94352fa15d0fad05364bf7f Mon Sep 17 00:00:00 2001 From: j178 <10510431+j178@users.noreply.github.com> Date: Wed, 8 Nov 2023 13:35:56 +0800 Subject: [PATCH] fix: fix azure support --- chatgpt.go | 17 +++++++++-------- config.go | 1 - 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/chatgpt.go b/chatgpt.go index 1413cb7..0f7dedb 100644 --- a/chatgpt.go +++ b/chatgpt.go @@ -17,16 +17,17 @@ type ChatGPT struct { } func NewChatGPT(conf GlobalConfig) *ChatGPT { - config := openai.DefaultConfig(conf.APIKey) - config.OrgID = conf.OrgID - if conf.Endpoint != "" { - config.BaseURL = conf.Endpoint - } - if conf.APIType != openai.APITypeOpenAI { - config.APIType = conf.APIType + var config openai.ClientConfig + if conf.APIType == openai.APITypeOpenAI { + config = openai.DefaultConfig(conf.APIKey) + if conf.Endpoint != "" { + config.BaseURL = conf.Endpoint + } + } else { + config = openai.DefaultAzureConfig(conf.APIKey, conf.Endpoint) config.APIVersion = conf.APIVersion - config.Engine = conf.Engine } + config.OrgID = conf.OrgID client := openai.NewClientWithConfig(config) return &ChatGPT{globalConf: conf, client: client} } diff --git a/config.go b/config.go index 37f4d9a..c60ff2b 100644 --- a/config.go +++ b/config.go @@ -44,7 +44,6 @@ type GlobalConfig struct { Endpoint string `json:"endpoint"` APIType openai.APIType `json:"api_type,omitempty"` APIVersion string `json:"api_version,omitempty"` // required when APIType is APITypeAzure or APITypeAzureAD - Engine string `json:"engine,omitempty"` // required when APIType is APITypeAzure or APITypeAzureAD OrgID string `json:"org_id,omitempty"` Prompts map[string]string `json:"prompts"` Conversation ConversationConfig `json:"conversation"` // Default conversation config