98 lines
2.4 KiB
Go
98 lines
2.4 KiB
Go
package probe
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
var (
	// ErrAuthFailed is the sentinel error ProviderModels returns when the
	// upstream answers with HTTP 401 or 403. Match it with errors.Is.
	ErrAuthFailed = errors.New("upstream auth failed")
)
|
|
|
|
// ModelsResult describes the outcome of a single ProviderModels probe.
type ModelsResult struct {
	// RawModels holds the model IDs reported by the upstream, whitespace-trimmed
	// and with blank IDs skipped. ProviderModels always sets it to a non-nil slice.
	RawModels []string

	// HTTPStatus is the status code of the upstream response.
	HTTPStatus int

	// LatencyMs is the time in milliseconds between sending the request and
	// receiving the response, measured before the body is read.
	LatencyMs int64

	// Error is a short machine-readable failure code: "auth_failed" for
	// 401/403, "unexpected_status_<code>" for other non-2xx statuses, and
	// empty on success.
	Error string
}
|
|
|
|
func ProviderModels(ctx context.Context, baseURL, apiKey string) (*ModelsResult, error) {
|
|
client := &http.Client{Timeout: 15 * time.Second}
|
|
|
|
requestURL, err := joinGatewayPath(baseURL, "/v1/models")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("resolve models endpoint: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("build models request: %w", err)
|
|
}
|
|
if token := strings.TrimSpace(apiKey); token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
startedAt := time.Now()
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request models: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
result := &ModelsResult{
|
|
RawModels: []string{},
|
|
HTTPStatus: resp.StatusCode,
|
|
LatencyMs: time.Since(startedAt).Milliseconds(),
|
|
}
|
|
|
|
var payload struct {
|
|
Data []struct {
|
|
ID string `json:"id"`
|
|
} `json:"data"`
|
|
Error any `json:"error"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
|
return nil, fmt.Errorf("decode models response: %w", err)
|
|
}
|
|
|
|
switch resp.StatusCode {
|
|
case http.StatusUnauthorized, http.StatusForbidden:
|
|
result.Error = "auth_failed"
|
|
return result, ErrAuthFailed
|
|
}
|
|
|
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
|
result.Error = fmt.Sprintf("unexpected_status_%d", resp.StatusCode)
|
|
return result, fmt.Errorf("models endpoint returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
for _, item := range payload.Data {
|
|
if modelID := strings.TrimSpace(item.ID); modelID != "" {
|
|
result.RawModels = append(result.RawModels, modelID)
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// joinGatewayPath returns the absolute URL formed by resolving path against
// baseURL.
//
// Both arguments are whitespace-trimmed. baseURL must parse and carry a scheme
// and a host; otherwise an error is returned. path is given a leading "/" if
// it lacks one. Because the reference is therefore always an absolute path,
// only the scheme, user info, and host of baseURL survive: any path, query,
// or fragment on baseURL is discarded rather than prefixed to path.
func joinGatewayPath(baseURL, path string) (string, error) {
	base, err := url.Parse(strings.TrimSpace(baseURL))
	if err != nil {
		return "", err
	}
	if base.Scheme == "" || base.Host == "" {
		return "", errors.New("base url must include scheme and host")
	}

	endpoint := strings.TrimSpace(path)
	if !strings.HasPrefix(endpoint, "/") {
		endpoint = "/" + endpoint
	}

	return base.ResolveReference(&url.URL{Path: endpoint}).String(), nil
}
|