Files
sub2api-cn-relay-manager/internal/probe/models.go

98 lines
2.4 KiB
Go

package probe
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
var (
	// ErrAuthFailed is the sentinel ProviderModels returns, unwrapped, when the
	// upstream rejects the models request with 401 Unauthorized or
	// 403 Forbidden. Match it with errors.Is.
	ErrAuthFailed = errors.New("upstream auth failed")
)
// ModelsResult describes the outcome of a ProviderModels probe against an
// upstream "/v1/models" endpoint.
type ModelsResult struct {
	// RawModels holds the whitespace-trimmed, non-empty "id" values from the
	// response's "data" array, in response order. ProviderModels initializes
	// it to an empty (non-nil) slice, so it stays empty when the probe fails
	// on status.
	RawModels []string

	// HTTPStatus is the status code the upstream returned.
	HTTPStatus int

	// LatencyMs is the number of milliseconds between sending the request and
	// client.Do returning; reading and decoding the body happens afterwards
	// and is not included.
	LatencyMs int64

	// Error is a short machine-readable failure code: "auth_failed" for
	// 401/403 responses, "unexpected_status_<code>" for any other status
	// outside 2xx, and empty on success.
	Error string
}
// ProviderModels queries the OpenAI-style "/v1/models" endpoint of the
// upstream at baseURL and reports the model IDs it advertises.
//
// When apiKey is non-blank it is sent as a Bearer token. The request is
// bounded by ctx and by a 15-second client timeout.
//
// Outcomes:
//   - 2xx: a ModelsResult whose RawModels holds every non-blank data[].id
//     (whitespace-trimmed, in response order) and a nil error.
//   - 401/403: a ModelsResult with Error "auth_failed" and ErrAuthFailed.
//   - any other non-2xx status: a ModelsResult with Error
//     "unexpected_status_<code>" and a non-nil error.
//   - endpoint-resolution, request, transport, or 2xx body-decode failures:
//     a nil result and a wrapped error.
//
// The status code is classified before the body is decoded, so auth and
// status failures are reported even when the upstream answers with an empty
// or non-JSON body (for example an HTML error page from a proxy).
func ProviderModels(ctx context.Context, baseURL, apiKey string) (*ModelsResult, error) {
	requestURL, err := joinGatewayPath(baseURL, "/v1/models")
	if err != nil {
		return nil, fmt.Errorf("resolve models endpoint: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
	if err != nil {
		return nil, fmt.Errorf("build models request: %w", err)
	}
	if token := strings.TrimSpace(apiKey); token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	client := &http.Client{Timeout: 15 * time.Second}
	startedAt := time.Now()
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request models: %w", err)
	}
	defer resp.Body.Close()

	// Latency is taken as soon as Do returns, so the body download and
	// decode below are not counted.
	result := &ModelsResult{
		RawModels:  []string{},
		HTTPStatus: resp.StatusCode,
		LatencyMs:  time.Since(startedAt).Milliseconds(),
	}

	// Classify the status before touching the body: error responses often
	// carry bodies that are not valid JSON.
	switch code := resp.StatusCode; {
	case code == http.StatusUnauthorized || code == http.StatusForbidden:
		result.Error = "auth_failed"
		return result, ErrAuthFailed
	case code < http.StatusOK || code >= http.StatusMultipleChoices:
		result.Error = fmt.Sprintf("unexpected_status_%d", code)
		return result, fmt.Errorf("models endpoint returned status %d", code)
	}

	var payload struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return nil, fmt.Errorf("decode models response: %w", err)
	}

	models := make([]string, 0, len(payload.Data))
	for _, item := range payload.Data {
		if modelID := strings.TrimSpace(item.ID); modelID != "" {
			models = append(models, modelID)
		}
	}
	result.RawModels = models
	return result, nil
}
// joinGatewayPath returns the absolute URL for path on the gateway named by
// baseURL. Surrounding whitespace is trimmed from both arguments, and a
// leading "/" is added to path when it is missing.
//
// Because path is resolved as an absolute-path reference (RFC 3986), only the
// scheme, userinfo, and host of baseURL are kept. Any path prefix, query, or
// fragment on baseURL is replaced. For example,
// ("https://h/api?x=1", "v1/models") yields "https://h/v1/models".
//
// An error is returned when baseURL does not parse or lacks a scheme or host.
func joinGatewayPath(baseURL, path string) (string, error) {
	gateway, err := url.Parse(strings.TrimSpace(baseURL))
	switch {
	case err != nil:
		return "", err
	case gateway.Scheme == "" || gateway.Host == "":
		return "", fmt.Errorf("base url must include scheme and host")
	}

	absolute := "/" + strings.TrimPrefix(strings.TrimSpace(path), "/")
	return gateway.ResolveReference(&url.URL{Path: absolute}).String(), nil
}