feat(probe): add model discovery and canonical family normalization
This commit is contained in:
126
internal/probe/aliases.go
Normal file
126
internal/probe/aliases.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package probe
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
type AliasResult struct {
|
||||
Raw string
|
||||
Normalized string
|
||||
Canonical string
|
||||
}
|
||||
|
||||
func NormalizeModelID(raw string) string {
|
||||
trimmed := strings.TrimSpace(strings.ToLower(raw))
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if idx := strings.LastIndex(trimmed, "/"); idx >= 0 {
|
||||
trimmed = trimmed[idx+1:]
|
||||
}
|
||||
|
||||
replacer := strings.NewReplacer("_", "-", " ", "-", ".", ".", "--", "-")
|
||||
normalized := replacer.Replace(trimmed)
|
||||
for strings.Contains(normalized, "--") {
|
||||
normalized = strings.ReplaceAll(normalized, "--", "-")
|
||||
}
|
||||
return strings.Trim(normalized, "-")
|
||||
}
|
||||
|
||||
func CanonicalModelID(raw string) string {
|
||||
return NormalizeModelID(raw)
|
||||
}
|
||||
|
||||
func CanonicalModelFamily(raw string) string {
|
||||
normalized := NormalizeModelID(raw)
|
||||
switch {
|
||||
case strings.HasPrefix(normalized, "kimi-k2."):
|
||||
return strings.Replace(normalized, "kimi-k2.", "kimi-2.", 1)
|
||||
case strings.HasPrefix(normalized, "kimi-k2-"):
|
||||
return strings.Replace(normalized, "kimi-k2-", "kimi-2-", 1)
|
||||
default:
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
|
||||
func BuildAliasTable(rawModels []string) map[string]AliasResult {
|
||||
table := make(map[string]AliasResult, len(rawModels)*4)
|
||||
for _, rawModel := range rawModels {
|
||||
rawModel = strings.TrimSpace(rawModel)
|
||||
if rawModel == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
result := AliasResult{
|
||||
Raw: rawModel,
|
||||
Normalized: NormalizeModelID(rawModel),
|
||||
Canonical: CanonicalModelFamily(rawModel),
|
||||
}
|
||||
|
||||
keys := []string{
|
||||
rawModel,
|
||||
result.Normalized,
|
||||
result.Canonical,
|
||||
CanonicalModelID(rawModel),
|
||||
lookupKey(rawModel),
|
||||
lookupKey(result.Normalized),
|
||||
lookupKey(result.Canonical),
|
||||
lookupKey(CanonicalModelID(rawModel)),
|
||||
}
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := table[key]; !exists {
|
||||
table[key] = result
|
||||
}
|
||||
}
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func ResolveRequestedModel(requested string, rawModels []string) (resolved string, ok bool) {
|
||||
result, ok := BuildAliasTable(rawModels)[lookupKey(requested)]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return result.Raw, true
|
||||
}
|
||||
|
||||
func RecommendModels(requested []string, rawModels []string) []string {
|
||||
table := BuildAliasTable(rawModels)
|
||||
recommended := make([]string, 0, len(requested))
|
||||
seen := make(map[string]struct{}, len(requested))
|
||||
|
||||
for _, requestedModel := range requested {
|
||||
result, ok := table[lookupKey(requestedModel)]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[result.Raw]; exists {
|
||||
continue
|
||||
}
|
||||
seen[result.Raw] = struct{}{}
|
||||
recommended = append(recommended, result.Raw)
|
||||
}
|
||||
|
||||
return recommended
|
||||
}
|
||||
|
||||
func lookupKey(raw string) string {
|
||||
canonical := CanonicalModelFamily(raw)
|
||||
if canonical == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
builder.Grow(len(canonical))
|
||||
for _, r := range canonical {
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) {
|
||||
builder.WriteRune(unicode.ToLower(r))
|
||||
}
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
61
internal/probe/aliases_test.go
Normal file
61
internal/probe/aliases_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package probe
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCanonicalModelFamily(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("kimi aliases collapse into one family", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
variants := []string{"kimi 2.6", "kimi-2.6", "kimi-k2.6", "Kimi-K2.6"}
|
||||
for _, variant := range variants {
|
||||
if got := CanonicalModelFamily(variant); got != "kimi-2.6" {
|
||||
t.Fatalf("CanonicalModelFamily(%q) = %q, want %q", variant, got, "kimi-2.6")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deepseek vendor prefix normalizes away", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := NormalizeModelID("deepseek-ai/DeepSeek-V4-Pro"); got != "deepseek-v4-pro" {
|
||||
t.Fatalf("NormalizeModelID() = %q, want %q", got, "deepseek-v4-pro")
|
||||
}
|
||||
if got := CanonicalModelID("deepseek-ai/DeepSeek-V4-Pro"); got != "deepseek-v4-pro" {
|
||||
t.Fatalf("CanonicalModelID() = %q, want %q", got, "deepseek-v4-pro")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("alias table and requested model resolution prefer discovered ids", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rawModels := []string{"deepseek-ai/DeepSeek-V4-Pro", "kimi-k2.6"}
|
||||
table := BuildAliasTable(rawModels)
|
||||
if got := table["deepseek-v4-pro"].Canonical; got != "deepseek-v4-pro" {
|
||||
t.Fatalf("alias canonical = %q, want %q", got, "deepseek-v4-pro")
|
||||
}
|
||||
|
||||
resolved, ok := ResolveRequestedModel("DeepSeek V4 Pro", rawModels)
|
||||
if !ok {
|
||||
t.Fatal("ResolveRequestedModel() ok = false, want true")
|
||||
}
|
||||
if resolved != "deepseek-ai/DeepSeek-V4-Pro" {
|
||||
t.Fatalf("ResolveRequestedModel() = %q, want discovered raw id", resolved)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("recommend models returns canonical discovered candidates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rawModels := []string{"kimi-k2.6", "deepseek-ai/DeepSeek-V4-Pro"}
|
||||
got := RecommendModels([]string{"kimi 2.6", "DeepSeek V4 Pro", "unknown"}, rawModels)
|
||||
want := []string{"kimi-k2.6", "deepseek-ai/DeepSeek-V4-Pro"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("RecommendModels() = %#v, want %#v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
97
internal/probe/models.go
Normal file
97
internal/probe/models.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package probe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrAuthFailed = errors.New("upstream auth failed")
|
||||
|
||||
type ModelsResult struct {
|
||||
RawModels []string
|
||||
HTTPStatus int
|
||||
LatencyMs int64
|
||||
Error string
|
||||
}
|
||||
|
||||
func ProviderModels(ctx context.Context, baseURL, apiKey string) (*ModelsResult, error) {
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
|
||||
requestURL, err := joinGatewayPath(baseURL, "/v1/models")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve models endpoint: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build models request: %w", err)
|
||||
}
|
||||
if token := strings.TrimSpace(apiKey); token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request models: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
result := &ModelsResult{
|
||||
RawModels: []string{},
|
||||
HTTPStatus: resp.StatusCode,
|
||||
LatencyMs: time.Since(startedAt).Milliseconds(),
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
Error any `json:"error"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return nil, fmt.Errorf("decode models response: %w", err)
|
||||
}
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusUnauthorized, http.StatusForbidden:
|
||||
result.Error = "auth_failed"
|
||||
return result, ErrAuthFailed
|
||||
}
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
result.Error = fmt.Sprintf("unexpected_status_%d", resp.StatusCode)
|
||||
return result, fmt.Errorf("models endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
for _, item := range payload.Data {
|
||||
if modelID := strings.TrimSpace(item.ID); modelID != "" {
|
||||
result.RawModels = append(result.RawModels, modelID)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func joinGatewayPath(baseURL, path string) (string, error) {
|
||||
parsedURL, err := url.Parse(strings.TrimSpace(baseURL))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
return "", fmt.Errorf("base url must include scheme and host")
|
||||
}
|
||||
|
||||
resolvedPath := strings.TrimSpace(path)
|
||||
if !strings.HasPrefix(resolvedPath, "/") {
|
||||
resolvedPath = "/" + resolvedPath
|
||||
}
|
||||
|
||||
return parsedURL.ResolveReference(&url.URL{Path: resolvedPath}).String(), nil
|
||||
}
|
||||
78
internal/probe/models_test.go
Normal file
78
internal/probe/models_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package probe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProviderModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("parses openai models response", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/models" {
|
||||
t.Fatalf("path = %q, want /v1/models", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer sk-test" {
|
||||
t.Fatalf("authorization = %q, want bearer auth", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"data":[{"id":" kimi 2.6 "},{"id":"deepseek-ai/DeepSeek-V4-Pro"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
result, err := ProviderModels(context.Background(), server.URL, "sk-test")
|
||||
if err != nil {
|
||||
t.Fatalf("ProviderModels() error = %v", err)
|
||||
}
|
||||
if result.HTTPStatus != http.StatusOK {
|
||||
t.Fatalf("HTTPStatus = %d, want %d", result.HTTPStatus, http.StatusOK)
|
||||
}
|
||||
if len(result.RawModels) != 2 {
|
||||
t.Fatalf("len(RawModels) = %d, want 2", len(result.RawModels))
|
||||
}
|
||||
if result.RawModels[0] != "kimi 2.6" || result.RawModels[1] != "deepseek-ai/DeepSeek-V4-Pro" {
|
||||
t.Fatalf("RawModels = %#v, want normalized trim order", result.RawModels)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty slice when upstream has no models", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"data":[]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
result, err := ProviderModels(context.Background(), server.URL, "sk-empty")
|
||||
if err != nil {
|
||||
t.Fatalf("ProviderModels() error = %v", err)
|
||||
}
|
||||
if result.HTTPStatus != http.StatusOK {
|
||||
t.Fatalf("HTTPStatus = %d, want %d", result.HTTPStatus, http.StatusOK)
|
||||
}
|
||||
if len(result.RawModels) != 0 {
|
||||
t.Fatalf("len(RawModels) = %d, want 0", len(result.RawModels))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("classifies auth failure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, `{"error":"forbidden"}`, http.StatusForbidden)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := ProviderModels(context.Background(), server.URL, "sk-nope")
|
||||
if !errors.Is(err, ErrAuthFailed) {
|
||||
t.Fatalf("ProviderModels() error = %v, want ErrAuthFailed", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user