summaryrefslogtreecommitdiff
path: root/doConnect.go
blob: 03e9dc999b7acdf53bc4bebe4c0185854f48b5dc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package main

import (
	"context"
	"fmt"
	"os"

	"go.wit.com/lib/protobuf/chatpb"
	"go.wit.com/log"
	"google.golang.org/genai"
)

// initGeminiAPI creates the shared Gemini client (me.client) and the context
// used for API calls (me.ctx). The API key is read from the GEMINI_API_KEY
// environment variable. Once initialization has succeeded, further calls are
// no-ops that return nil.
//
// me.ctx and me.client are only assigned after the client has been created
// successfully. A failed attempt therefore leaves both untouched and can be
// retried, instead of leaving a half-initialized state (non-nil ctx, nil
// client) that a later call would mistake for "already initialized".
func initGeminiAPI() error {
	if me.ctx != nil && me.client != nil {
		// already initialized
		return nil
	}
	apiKey := os.Getenv("GEMINI_API_KEY")
	if apiKey == "" {
		return log.Errorf("GEMINI_API_KEY environment variable not set")
	}

	ctx := context.Background()
	client, err := genai.NewClient(ctx, &genai.ClientConfig{APIKey: apiKey})
	if err != nil {
		return log.Errorf("failed to create new genai client: %w", err)
	}
	me.ctx = ctx
	me.client = client
	return nil
}

// doConnect sends the latest entry of me.lastChat to the Gemini API and
// records the model's reply in the chat history.
//
// Flow:
//   - ensure the Gemini client is initialized (initialization errors are returned);
//   - if the chat has no entries yet, seed it with a hardcoded prompt;
//   - convert the last entry's Gemini request and call GenerateContent;
//   - append the converted response to me.lastChat.Entries;
//   - for each function call in the first candidate, invoke handleFunctionCall
//     and append a FunctionResponse entry (only the name is mapped so far);
//   - if there were no function calls, print the text parts of every candidate.
//
// It logs and returns nil when me.lastChat is nil or the API returns an empty
// response. Conversion and API errors are returned to the caller.
func doConnect() error {
	// Without a client, me.client.Models below would panic, so an
	// initialization failure must stop here.
	if err := initGeminiAPI(); err != nil {
		return err
	}

	if me.lastChat == nil {
		log.Info("WTF. lastChat is nil")
		return nil
	}

	//	if me.lastChat.Entries == nil {
	//		me.lastChat.Entries = new(chatpb.ChatEntry)
	//	}

	// In a real application, you would get user input here.
	// For now, we'll use a hardcoded prompt.
	// NOTE(review): this seed entry only sets Parts, so GetGeminiRequest() on
	// it below presumably yields nil — confirm convertToGenai handles that.
	if len(me.lastChat.GetEntries()) == 0 {
		me.lastChat.Entries = append(me.lastChat.Entries, &chatpb.ChatEntry{
			Parts: []*chatpb.Part{
				{PartType: &chatpb.Part_Text{Text: "hello, how are you"}},
			},
		})
	}

	entries := me.lastChat.GetEntries()
	lastEntry := entries[len(entries)-1]
	genaiContents, err := convertToGenai(lastEntry.GetGeminiRequest())
	if err != nil {
		return err
	}

	resp, err := me.client.Models.GenerateContent(me.ctx, "gemini-2.5-flash", genaiContents, nil)
	if err != nil {
		return log.Errorf("error sending message: %w", err)
	}

	if resp == nil || len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil {
		log.Info("Received an empty response from the API. Stopping.")
		return nil
	}

	// Append the model's response to the history
	me.lastChat.Entries = append(me.lastChat.Entries, convertToPB(resp))

	// Only the first candidate is inspected for function calls.
	hasFunctionCall := false
	for _, part := range resp.Candidates[0].Content.Parts {
		if fc := part.FunctionCall; fc != nil {
			hasFunctionCall = true
			functionResponse := handleFunctionCall(fc)
			// Append the function response to the history for the next turn
			me.lastChat.Entries = append(me.lastChat.Entries, &chatpb.ChatEntry{
				Parts: []*chatpb.Part{
					{PartType: &chatpb.Part_FunctionResponse{
						FunctionResponse: &chatpb.FunctionResponse{
							Name: functionResponse.Name,
							// TODO: map response
						},
					}},
				},
			})
		}
	}

	// If there was no function call, print the text and stop.
	if !hasFunctionCall {
		log.Info("Response from API:")
		for _, cand := range resp.Candidates {
			if cand.Content != nil {
				for _, part := range cand.Content.Parts {
					if part.Text != "" {
						fmt.Println(part.Text)
					}
				}
			}
		}
	}

	return nil
}

// sampleHello sends a hardcoded prompt to the model and prints the response.
func simpleHello() error {
	log.Info("Sending 'hello, how are you' to the Gemini API...")

	// Create the parts slice
	parts := []*genai.Part{
		{Text: "What is my brothers name?"},
	}

	content := []*genai.Content{{Parts: parts}}

	resp, err := me.client.Models.GenerateContent(me.ctx, "gemini-2.5-flash", content, nil)
	if err != nil {
		return log.Errorf("error sending message: %v", err)
	}

	log.Info("Response from API:")
	for _, cand := range resp.Candidates {
		if cand.Content != nil {
			for _, part := range cand.Content.Parts {
				fmt.Println(part)
			}
		}
	}
	return nil
}