From a0e0ff65f58a4653068e713d4993d84f7192666d Mon Sep 17 00:00:00 2001 From: Ariya Hidayat Date: Wed, 3 Apr 2024 09:33:34 -0700 Subject: [PATCH] Go --- .github/workflows/test-go.yml | 25 +++++++ README.md | 8 ++- ask-llm.go | 120 ++++++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/test-go.yml create mode 100644 ask-llm.go diff --git a/.github/workflows/test-go.yml b/.github/workflows/test-go.yml new file mode 100644 index 0000000..e34b870 --- /dev/null +++ b/.github/workflows/test-go.yml @@ -0,0 +1,25 @@ +name: Test with Go + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-22.04 + timeout-minutes: 10 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: '>=1.17.0' + + - run: go version + + - name: Prepare LLM (TinyLlama) + uses: ./.github/actions/prepare-llm + timeout-minutes: 3 + + - run: echo 'Which planet in our solar system is the largest?' | go run ./ask-llm.go | grep -i jupiter + timeout-minutes: 7 + env: + LLM_API_BASE_URL: 'http://127.0.0.1:8080/v1' + LLM_DEBUG: 1 \ No newline at end of file diff --git a/README.md b/README.md index 292de33..ab0c9a8 100644 --- a/README.md +++ b/README.md @@ -9,12 +9,14 @@ It is available in several flavors: * Python version. Compatible with [CPython](https://python.org) or [PyPy](https://pypy.org), v3.10 or higher. * JavaScript version. Compatible with [Node.js](https://nodejs.org) (>= v18) or [Bun](https://bun.sh) (>= v1.0). * Clojure version. Compatible with [Babashka](https://babashka.org/) (>= 1.3). +* Go version. Compatible with [Go](https://golang.org), v1.19 or higher. 
// ask-llm is a minimal chat client for an OpenAI-compatible
// "/chat/completions" endpoint.
//
// Usage:
//
//	go run ask-llm.go                                 # interactive session
//	echo 'Some question?' | go run ./ask-llm.go       # answer piped input
//
// Configuration is read from the environment:
//
//	LLM_API_BASE_URL  base URL of the API, e.g. http://127.0.0.1:8080/v1
//	LLM_API_KEY       optional; sent as a Bearer token when non-empty
//	LLM_CHAT_MODEL    model name placed in the request body
//	LLM_DEBUG         when non-empty, print the round-trip time of each answer

// Settings captured from the environment once, at program start.
var (
	LLMAPIBaseURL = os.Getenv("LLM_API_BASE_URL")
	LLMAPIKey     = os.Getenv("LLM_API_KEY")
	LLMChatModel  = os.Getenv("LLM_CHAT_MODEL")
	LLMDebug      = os.Getenv("LLM_DEBUG")
)

// requestTimeout bounds a whole chat round trip (connect, send, read body).
// It is generous because local inference can be slow, but finite so that a
// stalled server cannot hang the program forever.
const requestTimeout = 5 * time.Minute

// httpClient is shared by every request so connections can be reused.
var httpClient = &http.Client{Timeout: requestTimeout}

// Message is a single turn of the conversation.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ChatRequest is the JSON body posted to the chat completions endpoint.
type ChatRequest struct {
	Messages    []Message `json:"messages"`
	Model       string    `json:"model"`
	MaxTokens   int       `json:"max_tokens"`
	Temperature float64   `json:"temperature"`
}

// Choice is one entry of the "choices" array in the response; only the
// message content is decoded.
type Choice struct {
	Message struct {
		Content string `json:"content"`
	} `json:"message"`
}

// chat posts the conversation so far to LLMAPIBaseURL + "/chat/completions"
// and returns the content of the first choice in the reply.
//
// It returns an error when the request cannot be built or sent, when the
// server answers with a status other than 200, when the body is not valid
// JSON, or when the reply carries no choices.
func chat(messages []Message) (string, error) {
	payload, err := json.Marshal(ChatRequest{
		Messages:    messages,
		Model:       LLMChatModel,
		MaxTokens:   200,
		Temperature: 0,
	})
	if err != nil {
		return "", fmt.Errorf("encoding chat request: %w", err)
	}

	url := LLMAPIBaseURL + "/chat/completions"
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(payload))
	if err != nil {
		return "", fmt.Errorf("building chat request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if LLMAPIKey != "" {
		req.Header.Set("Authorization", "Bearer "+LLMAPIKey)
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("sending chat request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// resp.Status already includes the numeric code ("404 Not Found").
		return "", fmt.Errorf("HTTP error: %s", resp.Status)
	}

	var data struct {
		Choices []Choice `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
		return "", fmt.Errorf("decoding chat response: %w", err)
	}
	// Guard the index below: a reply without choices must not panic.
	if len(data.Choices) == 0 {
		return "", fmt.Errorf("chat response from %s contains no choices", url)
	}
	return data.Choices[0].Message.Content, nil
}

// SystemPrompt is the system message that opens every conversation.
const SystemPrompt = "Answer the question politely and concisely."

// main runs the read-ask-print loop: it reads one question per line from
// standard input, sends the accumulated conversation to the LLM and prints
// the answer. The loop ends on end of input or on an empty line. Any failure
// is reported on standard error and ends the program with exit status 1.
func main() {
	if LLMAPIBaseURL == "" {
		fmt.Fprintln(os.Stderr, "Error: LLM_API_BASE_URL is not set")
		os.Exit(1)
	}

	fmt.Printf("Using LLM at %s.\n", LLMAPIBaseURL)
	fmt.Println("Press Ctrl+D to exit.")
	fmt.Println()

	messages := []Message{{Role: "system", Content: SystemPrompt}}
	scanner := bufio.NewScanner(os.Stdin)

	for {
		fmt.Print(">> ")
		if !scanner.Scan() {
			break // end of input, or a read error reported after the loop
		}
		question := scanner.Text()
		if question == "" {
			break
		}

		messages = append(messages, Message{Role: "user", Content: question})
		start := time.Now()
		answer, err := chat(messages)
		// Take the timing before printing so it covers only the request.
		elapsed := time.Since(start)
		if err != nil {
			fmt.Fprintln(os.Stderr, "Error:", err)
			os.Exit(1)
		}

		messages = append(messages, Message{Role: "assistant", Content: answer})
		fmt.Println(answer)
		if LLMDebug != "" {
			fmt.Printf("[%d ms]\n", elapsed.Milliseconds())
		}
		fmt.Println()
	}

	if err := scanner.Err(); err != nil {
		fmt.Fprintln(os.Stderr, "Error: reading input:", err)
		os.Exit(1)
	}
}