feat(probe): add model discovery and canonical family normalization

This commit is contained in:
phamnazage-jpg
2026-05-22 14:29:51 +08:00
parent 11182f2d4a
commit 2bc7554cf8
4 changed files with 362 additions and 0 deletions

126
internal/probe/aliases.go Normal file
View File

@@ -0,0 +1,126 @@
package probe
import (
"strings"
"unicode"
)
type AliasResult struct {
Raw string
Normalized string
Canonical string
}
func NormalizeModelID(raw string) string {
trimmed := strings.TrimSpace(strings.ToLower(raw))
if trimmed == "" {
return ""
}
if idx := strings.LastIndex(trimmed, "/"); idx >= 0 {
trimmed = trimmed[idx+1:]
}
replacer := strings.NewReplacer("_", "-", " ", "-", ".", ".", "--", "-")
normalized := replacer.Replace(trimmed)
for strings.Contains(normalized, "--") {
normalized = strings.ReplaceAll(normalized, "--", "-")
}
return strings.Trim(normalized, "-")
}
func CanonicalModelID(raw string) string {
return NormalizeModelID(raw)
}
func CanonicalModelFamily(raw string) string {
normalized := NormalizeModelID(raw)
switch {
case strings.HasPrefix(normalized, "kimi-k2."):
return strings.Replace(normalized, "kimi-k2.", "kimi-2.", 1)
case strings.HasPrefix(normalized, "kimi-k2-"):
return strings.Replace(normalized, "kimi-k2-", "kimi-2-", 1)
default:
return normalized
}
}
func BuildAliasTable(rawModels []string) map[string]AliasResult {
table := make(map[string]AliasResult, len(rawModels)*4)
for _, rawModel := range rawModels {
rawModel = strings.TrimSpace(rawModel)
if rawModel == "" {
continue
}
result := AliasResult{
Raw: rawModel,
Normalized: NormalizeModelID(rawModel),
Canonical: CanonicalModelFamily(rawModel),
}
keys := []string{
rawModel,
result.Normalized,
result.Canonical,
CanonicalModelID(rawModel),
lookupKey(rawModel),
lookupKey(result.Normalized),
lookupKey(result.Canonical),
lookupKey(CanonicalModelID(rawModel)),
}
for _, key := range keys {
if key == "" {
continue
}
if _, exists := table[key]; !exists {
table[key] = result
}
}
}
return table
}
func ResolveRequestedModel(requested string, rawModels []string) (resolved string, ok bool) {
result, ok := BuildAliasTable(rawModels)[lookupKey(requested)]
if !ok {
return "", false
}
return result.Raw, true
}
func RecommendModels(requested []string, rawModels []string) []string {
table := BuildAliasTable(rawModels)
recommended := make([]string, 0, len(requested))
seen := make(map[string]struct{}, len(requested))
for _, requestedModel := range requested {
result, ok := table[lookupKey(requestedModel)]
if !ok {
continue
}
if _, exists := seen[result.Raw]; exists {
continue
}
seen[result.Raw] = struct{}{}
recommended = append(recommended, result.Raw)
}
return recommended
}
func lookupKey(raw string) string {
canonical := CanonicalModelFamily(raw)
if canonical == "" {
return ""
}
var builder strings.Builder
builder.Grow(len(canonical))
for _, r := range canonical {
if unicode.IsLetter(r) || unicode.IsDigit(r) {
builder.WriteRune(unicode.ToLower(r))
}
}
return builder.String()
}

View File

@@ -0,0 +1,61 @@
package probe
import (
"reflect"
"testing"
)
func TestCanonicalModelFamily(t *testing.T) {
t.Parallel()
t.Run("kimi aliases collapse into one family", func(t *testing.T) {
t.Parallel()
variants := []string{"kimi 2.6", "kimi-2.6", "kimi-k2.6", "Kimi-K2.6"}
for _, variant := range variants {
if got := CanonicalModelFamily(variant); got != "kimi-2.6" {
t.Fatalf("CanonicalModelFamily(%q) = %q, want %q", variant, got, "kimi-2.6")
}
}
})
t.Run("deepseek vendor prefix normalizes away", func(t *testing.T) {
t.Parallel()
if got := NormalizeModelID("deepseek-ai/DeepSeek-V4-Pro"); got != "deepseek-v4-pro" {
t.Fatalf("NormalizeModelID() = %q, want %q", got, "deepseek-v4-pro")
}
if got := CanonicalModelID("deepseek-ai/DeepSeek-V4-Pro"); got != "deepseek-v4-pro" {
t.Fatalf("CanonicalModelID() = %q, want %q", got, "deepseek-v4-pro")
}
})
t.Run("alias table and requested model resolution prefer discovered ids", func(t *testing.T) {
t.Parallel()
rawModels := []string{"deepseek-ai/DeepSeek-V4-Pro", "kimi-k2.6"}
table := BuildAliasTable(rawModels)
if got := table["deepseek-v4-pro"].Canonical; got != "deepseek-v4-pro" {
t.Fatalf("alias canonical = %q, want %q", got, "deepseek-v4-pro")
}
resolved, ok := ResolveRequestedModel("DeepSeek V4 Pro", rawModels)
if !ok {
t.Fatal("ResolveRequestedModel() ok = false, want true")
}
if resolved != "deepseek-ai/DeepSeek-V4-Pro" {
t.Fatalf("ResolveRequestedModel() = %q, want discovered raw id", resolved)
}
})
t.Run("recommend models returns canonical discovered candidates", func(t *testing.T) {
t.Parallel()
rawModels := []string{"kimi-k2.6", "deepseek-ai/DeepSeek-V4-Pro"}
got := RecommendModels([]string{"kimi 2.6", "DeepSeek V4 Pro", "unknown"}, rawModels)
want := []string{"kimi-k2.6", "deepseek-ai/DeepSeek-V4-Pro"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("RecommendModels() = %#v, want %#v", got, want)
}
})
}

97
internal/probe/models.go Normal file
View File

@@ -0,0 +1,97 @@
package probe
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
var ErrAuthFailed = errors.New("upstream auth failed")
type ModelsResult struct {
RawModels []string
HTTPStatus int
LatencyMs int64
Error string
}
func ProviderModels(ctx context.Context, baseURL, apiKey string) (*ModelsResult, error) {
client := &http.Client{Timeout: 15 * time.Second}
requestURL, err := joinGatewayPath(baseURL, "/v1/models")
if err != nil {
return nil, fmt.Errorf("resolve models endpoint: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return nil, fmt.Errorf("build models request: %w", err)
}
if token := strings.TrimSpace(apiKey); token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
startedAt := time.Now()
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request models: %w", err)
}
defer resp.Body.Close()
result := &ModelsResult{
RawModels: []string{},
HTTPStatus: resp.StatusCode,
LatencyMs: time.Since(startedAt).Milliseconds(),
}
var payload struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
Error any `json:"error"`
}
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return nil, fmt.Errorf("decode models response: %w", err)
}
switch resp.StatusCode {
case http.StatusUnauthorized, http.StatusForbidden:
result.Error = "auth_failed"
return result, ErrAuthFailed
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
result.Error = fmt.Sprintf("unexpected_status_%d", resp.StatusCode)
return result, fmt.Errorf("models endpoint returned status %d", resp.StatusCode)
}
for _, item := range payload.Data {
if modelID := strings.TrimSpace(item.ID); modelID != "" {
result.RawModels = append(result.RawModels, modelID)
}
}
return result, nil
}
func joinGatewayPath(baseURL, path string) (string, error) {
parsedURL, err := url.Parse(strings.TrimSpace(baseURL))
if err != nil {
return "", err
}
if parsedURL.Scheme == "" || parsedURL.Host == "" {
return "", fmt.Errorf("base url must include scheme and host")
}
resolvedPath := strings.TrimSpace(path)
if !strings.HasPrefix(resolvedPath, "/") {
resolvedPath = "/" + resolvedPath
}
return parsedURL.ResolveReference(&url.URL{Path: resolvedPath}).String(), nil
}

View File

@@ -0,0 +1,78 @@
package probe
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
)
func TestProviderModels(t *testing.T) {
t.Parallel()
t.Run("parses openai models response", func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/models" {
t.Fatalf("path = %q, want /v1/models", r.URL.Path)
}
if got := r.Header.Get("Authorization"); got != "Bearer sk-test" {
t.Fatalf("authorization = %q, want bearer auth", got)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"data":[{"id":" kimi 2.6 "},{"id":"deepseek-ai/DeepSeek-V4-Pro"}]}`))
}))
defer server.Close()
result, err := ProviderModels(context.Background(), server.URL, "sk-test")
if err != nil {
t.Fatalf("ProviderModels() error = %v", err)
}
if result.HTTPStatus != http.StatusOK {
t.Fatalf("HTTPStatus = %d, want %d", result.HTTPStatus, http.StatusOK)
}
if len(result.RawModels) != 2 {
t.Fatalf("len(RawModels) = %d, want 2", len(result.RawModels))
}
if result.RawModels[0] != "kimi 2.6" || result.RawModels[1] != "deepseek-ai/DeepSeek-V4-Pro" {
t.Fatalf("RawModels = %#v, want normalized trim order", result.RawModels)
}
})
t.Run("returns empty slice when upstream has no models", func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"data":[]}`))
}))
defer server.Close()
result, err := ProviderModels(context.Background(), server.URL, "sk-empty")
if err != nil {
t.Fatalf("ProviderModels() error = %v", err)
}
if result.HTTPStatus != http.StatusOK {
t.Fatalf("HTTPStatus = %d, want %d", result.HTTPStatus, http.StatusOK)
}
if len(result.RawModels) != 0 {
t.Fatalf("len(RawModels) = %d, want 0", len(result.RawModels))
}
})
t.Run("classifies auth failure", func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error":"forbidden"}`, http.StatusForbidden)
}))
defer server.Close()
_, err := ProviderModels(context.Background(), server.URL, "sk-nope")
if !errors.Is(err, ErrAuthFailed) {
t.Fatalf("ProviderModels() error = %v, want ErrAuthFailed", err)
}
})
}