1237 lines
40 KiB
Go
1237 lines
40 KiB
Go
package sub2api
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestHTTPErrorErrorMessage(t *testing.T) {
|
|
e := newHTTPError("POST", "/api/v1/admin/groups", http.StatusTeapot, []byte("short and stout"))
|
|
want := "sub2api POST /api/v1/admin/groups returned 418: short and stout"
|
|
if got := e.Error(); got != want {
|
|
t.Fatalf("HTTPError.Error() = %q, want %q", got, want)
|
|
}
|
|
}
|
|
|
|
func TestWithHTTPClientAndOptions(t *testing.T) {
|
|
customHTTP := &http.Client{Timeout: 123}
|
|
client, err := NewClient("http://localhost:8080",
|
|
WithHTTPClient(customHTTP),
|
|
WithAPIKey(" sk-abc "),
|
|
WithBearerToken(" tok-xyz "),
|
|
)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if client.httpClient != customHTTP {
|
|
t.Fatal("WithHTTPClient not applied")
|
|
}
|
|
if client.apiKey != "sk-abc" {
|
|
t.Fatalf("apiKey = %q, want %q", client.apiKey, "sk-abc")
|
|
}
|
|
if client.bearerToken != "tok-xyz" {
|
|
t.Fatalf("bearerToken = %q, want %q", client.bearerToken, "tok-xyz")
|
|
}
|
|
}
|
|
|
|
func TestNewClient_RejectsInvalidURLs(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
url string
|
|
}{
|
|
{"empty", ""},
|
|
{"no scheme", "localhost:8080"},
|
|
{"no host", "http://"},
|
|
{"garbage", "://foo"},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
_, err := NewClient(tt.url)
|
|
if err == nil {
|
|
t.Fatalf("NewClient(%q) error = nil, want error", tt.url)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestResolvePath(t *testing.T) {
|
|
client, err := NewClient("http://host:9090")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
tests := []struct {
|
|
path string
|
|
want string
|
|
}{
|
|
{"/v1/models", "http://host:9090/v1/models"},
|
|
{"v1/models", "http://host:9090/v1/models"},
|
|
{"/v1/models?key=val", "http://host:9090/v1/models?key=val"},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.path, func(t *testing.T) {
|
|
if got := client.resolvePath(tt.path); got != tt.want {
|
|
t.Fatalf("resolvePath(%q) = %q, want %q", tt.path, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestApplyAuth(t *testing.T) {
|
|
t.Run("api key preferred", func(t *testing.T) {
|
|
c, _ := NewClient("http://h:8080", WithAPIKey("key1"), WithBearerToken("btok"))
|
|
req, _ := http.NewRequest("GET", "http://h:8080/path", nil)
|
|
c.applyAuth(req)
|
|
if h := req.Header.Get("x-api-key"); h != "key1" {
|
|
t.Fatalf("x-api-key = %q, want %q", h, "key1")
|
|
}
|
|
if h := req.Header.Get("Authorization"); h != "" {
|
|
t.Fatalf("Authorization should be empty, got %q", h)
|
|
}
|
|
})
|
|
t.Run("bearer token fallback", func(t *testing.T) {
|
|
c, _ := NewClient("http://h:8080", WithBearerToken("btok"))
|
|
req, _ := http.NewRequest("GET", "http://h:8080/path", nil)
|
|
c.applyAuth(req)
|
|
if h := req.Header.Get("Authorization"); h != "Bearer btok" {
|
|
t.Fatalf("Authorization = %q, want %q", h, "Bearer btok")
|
|
}
|
|
})
|
|
t.Run("no auth", func(t *testing.T) {
|
|
c, _ := NewClient("http://h:8080")
|
|
req, _ := http.NewRequest("GET", "http://h:8080/path", nil)
|
|
c.applyAuth(req)
|
|
if h := req.Header.Get("x-api-key"); h != "" {
|
|
t.Fatalf("x-api-key should be empty, got %q", h)
|
|
}
|
|
if h := req.Header.Get("Authorization"); h != "" {
|
|
t.Fatalf("Authorization should be empty, got %q", h)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDecodeEnvelopeObject(t *testing.T) {
|
|
t.Run("standard envelope", func(t *testing.T) {
|
|
body := []byte(`{"data":{"id":"g1","name":"test"}}`)
|
|
var ref GroupRef
|
|
if err := decodeEnvelopeObject(body, &ref); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ref.ID != "g1" || ref.Name != "test" {
|
|
t.Fatalf("got %+v, want {ID:g1 Name:test}", ref)
|
|
}
|
|
})
|
|
t.Run("flat response (no data wrapper)", func(t *testing.T) {
|
|
body := []byte(`{"id":"g2","name":"flat"}`)
|
|
var ref GroupRef
|
|
if err := decodeEnvelopeObject(body, &ref); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ref.ID != "g2" || ref.Name != "flat" {
|
|
t.Fatalf("got %+v, want {ID:g2 Name:flat}", ref)
|
|
}
|
|
})
|
|
t.Run("data:null returns flat", func(t *testing.T) {
|
|
body := []byte(`{"data":null,"id":"g3"}`)
|
|
var ref GroupRef
|
|
if err := decodeEnvelopeObject(body, &ref); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ref.ID != "g3" {
|
|
t.Fatalf("id = %q, want %q", ref.ID, "g3")
|
|
}
|
|
})
|
|
t.Run("invalid json returns error", func(t *testing.T) {
|
|
var ref GroupRef
|
|
if err := decodeEnvelopeObject([]byte(`not json`), &ref); err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDecodeGatewayModelIDs(t *testing.T) {
|
|
t.Run("standard list", func(t *testing.T) {
|
|
ids := decodeGatewayModelIDs([]byte(`{"data":[{"id":"gpt-4"},{"id":" claude-3 "}]}`))
|
|
if len(ids) != 2 || ids[0] != "gpt-4" || ids[1] != "claude-3" {
|
|
t.Fatalf("got %v, want [gpt-4 claude-3]", ids)
|
|
}
|
|
})
|
|
t.Run("empty data", func(t *testing.T) {
|
|
if ids := decodeGatewayModelIDs([]byte(`{}`)); ids != nil {
|
|
t.Fatalf("expected nil, got %v", ids)
|
|
}
|
|
})
|
|
t.Run("invalid json", func(t *testing.T) {
|
|
if ids := decodeGatewayModelIDs([]byte(`not json`)); ids != nil {
|
|
t.Fatalf("expected nil, got %v", ids)
|
|
}
|
|
})
|
|
t.Run("empty array", func(t *testing.T) {
|
|
if ids := decodeGatewayModelIDs([]byte(`{"data":[]}`)); ids != nil {
|
|
t.Fatalf("expected nil, got %v", ids)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestFilterNamedResourcesByName(t *testing.T) {
|
|
resources := []NamedResource{
|
|
{Name: "group-a", ID: "g1"},
|
|
{Name: "group-b", ID: "g2"},
|
|
{Name: " group-a ", ID: "g3"},
|
|
}
|
|
t.Run("match", func(t *testing.T) {
|
|
got := filterNamedResourcesByName(resources, "group-a")
|
|
if len(got) != 2 || got[0].ID != "g1" || got[1].ID != "g3" {
|
|
t.Fatalf("got %+v, want 2 matches", got)
|
|
}
|
|
})
|
|
t.Run("no match", func(t *testing.T) {
|
|
if got := filterNamedResourcesByName(resources, "nonexistent"); len(got) != 0 {
|
|
t.Fatalf("expected 0, got %d", len(got))
|
|
}
|
|
})
|
|
t.Run("empty name returns all", func(t *testing.T) {
|
|
if got := filterNamedResourcesByName(resources, ""); len(got) != 3 {
|
|
t.Fatalf("expected 3, got %d", len(got))
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestFilterNamedResourcesByPrefix(t *testing.T) {
|
|
resources := []NamedResource{
|
|
{Name: "deepseek-proxy", ID: "r1"},
|
|
{Name: "deepseek-us", ID: "r2"},
|
|
{Name: "claude-eu", ID: "r3"},
|
|
}
|
|
t.Run("prefix matches", func(t *testing.T) {
|
|
got := filterNamedResourcesByPrefix(resources, "deepseek")
|
|
if len(got) != 2 {
|
|
t.Fatalf("expected 2, got %d", len(got))
|
|
}
|
|
})
|
|
t.Run("no prefix match", func(t *testing.T) {
|
|
if got := filterNamedResourcesByPrefix(resources, "nope"); len(got) != 0 {
|
|
t.Fatalf("expected 0, got %d", len(got))
|
|
}
|
|
})
|
|
t.Run("empty prefix returns all", func(t *testing.T) {
|
|
if got := filterNamedResourcesByPrefix(resources, ""); len(got) != 3 {
|
|
t.Fatalf("expected 3, got %d", len(got))
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDecodeNamedResources(t *testing.T) {
|
|
t.Run("envelope", func(t *testing.T) {
|
|
resources, pages, err := decodeNamedResources([]byte(`{"data":[{"id":"r1","name":"n1"}]}`))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if pages != 1 {
|
|
t.Fatalf("pages = %d, want 1", pages)
|
|
}
|
|
if len(resources) != 1 || resources[0].ID != "r1" {
|
|
t.Fatalf("got %+v", resources)
|
|
}
|
|
})
|
|
t.Run("numeric id", func(t *testing.T) {
|
|
resources, pages, err := decodeNamedResources([]byte(`{"data":{"items":[{"id":1,"name":"default"}],"pages":2}}`))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if pages != 2 {
|
|
t.Fatalf("pages = %d, want 2", pages)
|
|
}
|
|
if len(resources) != 1 || resources[0].ID != "1" {
|
|
t.Fatalf("got %+v", resources)
|
|
}
|
|
})
|
|
t.Run("wrapper with items", func(t *testing.T) {
|
|
resources, pages, err := decodeNamedResources([]byte(`{"data":{"items":[{"id":"r2","name":"n2"}]}}`))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if pages != 1 {
|
|
t.Fatalf("pages = %d, want 1", pages)
|
|
}
|
|
if len(resources) != 1 || resources[0].ID != "r2" {
|
|
t.Fatalf("got %+v", resources)
|
|
}
|
|
})
|
|
t.Run("invalid json", func(t *testing.T) {
|
|
_, _, err := decodeNamedResources([]byte(`not json`))
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDecodeAccountRefs(t *testing.T) {
|
|
t.Run("envelope", func(t *testing.T) {
|
|
refs, err := decodeAccountRefs([]byte(`{"data":[{"id":"a1"}]}`))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(refs) != 1 || refs[0].ID != "a1" {
|
|
t.Fatalf("got %+v", refs)
|
|
}
|
|
})
|
|
t.Run("numeric id", func(t *testing.T) {
|
|
refs, err := decodeAccountRefs([]byte(`{"data":{"items":[{"id":42}]}}`))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(refs) != 1 || refs[0].ID != "42" {
|
|
t.Fatalf("got %+v", refs)
|
|
}
|
|
})
|
|
t.Run("wrapper with items", func(t *testing.T) {
|
|
refs, err := decodeAccountRefs([]byte(`{"data":{"items":[{"id":"a2"}]}}`))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(refs) != 1 || refs[0].ID != "a2" {
|
|
t.Fatalf("got %+v", refs)
|
|
}
|
|
})
|
|
t.Run("batch results", func(t *testing.T) {
|
|
refs, err := decodeAccountRefs([]byte(`{"success":1,"failed":0,"results":[{"name":"k1","id":123,"success":true}]}`))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(refs) != 1 || refs[0].ID != "123" || refs[0].Name != "k1" {
|
|
t.Fatalf("got %+v", refs)
|
|
}
|
|
})
|
|
t.Run("batch results ignores failed items", func(t *testing.T) {
|
|
refs, err := decodeAccountRefs([]byte(`{"success":1,"failed":1,"results":[{"name":"k1","id":123,"success":true},{"name":"k2","id":456,"success":false}]}`))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(refs) != 1 || refs[0].ID != "123" {
|
|
t.Fatalf("got %+v", refs)
|
|
}
|
|
})
|
|
t.Run("data wrapped batch results", func(t *testing.T) {
|
|
refs, err := decodeAccountRefs([]byte(`{"code":0,"message":"success","data":{"failed":0,"results":[{"id":5,"name":"deepseek-01","success":true}],"success":1}}`))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(refs) != 1 || refs[0].ID != "5" || refs[0].Name != "deepseek-01" {
|
|
t.Fatalf("got %+v", refs)
|
|
}
|
|
})
|
|
t.Run("invalid json", func(t *testing.T) {
|
|
_, err := decodeAccountRefs([]byte(`not json`))
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDecodeAccountModels(t *testing.T) {
|
|
t.Run("envelope", func(t *testing.T) {
|
|
models, err := decodeAccountModels([]byte(`{"data":[{"id":"gpt4","display_name":"GPT-4","type":"chat"}]}`))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(models) != 1 || models[0].ID != "gpt4" {
|
|
t.Fatalf("got %+v", models)
|
|
}
|
|
})
|
|
t.Run("wrapper with items", func(t *testing.T) {
|
|
models, err := decodeAccountModels([]byte(`{"data":{"items":[{"id":"cl3","display_name":"Claude 3","type":"chat"}]}}`))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(models) != 1 || models[0].ID != "cl3" {
|
|
t.Fatalf("got %+v", models)
|
|
}
|
|
})
|
|
t.Run("invalid json", func(t *testing.T) {
|
|
_, err := decodeAccountModels([]byte(`not json`))
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestParseProbeResult(t *testing.T) {
|
|
t.Run("SSE with ok=true", func(t *testing.T) {
|
|
result, err := parseProbeResult([]byte("data: {\"status\":\"passed\",\"ok\":true}\n"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !result.OK || result.Status != "passed" {
|
|
t.Fatalf("got %+v, want OK=true Status=passed", result)
|
|
}
|
|
})
|
|
t.Run("SSE with success=true", func(t *testing.T) {
|
|
result, err := parseProbeResult([]byte("data: {\"status\":\"succeeded\",\"success\":true}\n"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !result.OK || result.Status != "passed" {
|
|
t.Fatalf("got %+v", result)
|
|
}
|
|
})
|
|
t.Run("SSE with ok=false", func(t *testing.T) {
|
|
result, err := parseProbeResult([]byte("data: {\"status\":\"failed\",\"ok\":false}\n"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result.OK || result.Status != "failed" {
|
|
t.Fatalf("got %+v", result)
|
|
}
|
|
})
|
|
t.Run("SSE with status-based ok", func(t *testing.T) {
|
|
result, err := parseProbeResult([]byte("data: {\"status\":\"pass\",\"message\":\"all good\"}\n"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !result.OK || result.Message != "all good" {
|
|
t.Fatalf("got %+v", result)
|
|
}
|
|
})
|
|
t.Run("multiple SSE events picks last", func(t *testing.T) {
|
|
result, err := parseProbeResult([]byte("data: {\"status\":\"running\"}\ndata: {\"status\":\"passed\",\"ok\":true}\n"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !result.OK {
|
|
t.Fatalf("expected OK=true from last event, got %+v", result)
|
|
}
|
|
})
|
|
t.Run("no data events", func(t *testing.T) {
|
|
_, err := parseProbeResult([]byte("not data\n"))
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestNormalizeProbeStatus(t *testing.T) {
|
|
tests := []struct {
|
|
status string
|
|
ok bool
|
|
want string
|
|
}{
|
|
{"pass", true, "passed"},
|
|
{"PASSED", true, "passed"},
|
|
{"Ok", true, "passed"},
|
|
{"success", true, "passed"},
|
|
{"succeeded", true, "passed"},
|
|
{"fail", false, "failed"},
|
|
{"FAILED", false, "failed"},
|
|
{"error", false, "failed"},
|
|
{"custom_ok", true, "passed"},
|
|
{"custom_fail", false, "failed"},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.status, func(t *testing.T) {
|
|
if got := normalizeProbeStatus(tt.status, tt.ok); got != tt.want {
|
|
t.Fatalf("normalizeProbeStatus(%q, %v) = %q, want %q", tt.status, tt.ok, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
func TestLooksLikeExistingEndpoint(t *testing.T) {
|
|
t.Run("json content type", func(t *testing.T) {
|
|
h := http.Header{"Content-Type": []string{"application/json"}}
|
|
if !looksLikeExistingEndpoint(h, nil) {
|
|
t.Fatal("expected true with json content type")
|
|
}
|
|
})
|
|
t.Run("sse content type", func(t *testing.T) {
|
|
h := http.Header{"Content-Type": []string{"text/event-stream"}}
|
|
if !looksLikeExistingEndpoint(h, nil) {
|
|
t.Fatal("expected true with sse content type")
|
|
}
|
|
})
|
|
t.Run("empty body and no content type", func(t *testing.T) {
|
|
if looksLikeExistingEndpoint(http.Header{}, nil) {
|
|
t.Fatal("expected false")
|
|
}
|
|
})
|
|
t.Run("json-like body", func(t *testing.T) {
|
|
if !looksLikeExistingEndpoint(http.Header{}, []byte(`{"error":"not found"}`)) {
|
|
t.Fatal("expected true for json body")
|
|
}
|
|
})
|
|
t.Run("array body", func(t *testing.T) {
|
|
if !looksLikeExistingEndpoint(http.Header{}, []byte(`[]`)) {
|
|
t.Fatal("expected true for array body")
|
|
}
|
|
})
|
|
t.Run("html body", func(t *testing.T) {
|
|
if looksLikeExistingEndpoint(http.Header{}, []byte(`<html>`)) {
|
|
t.Fatal("expected false for html body")
|
|
}
|
|
})
|
|
}
|
|
|
|
// Tests for NamedResource type used by the filter functions.
|
|
// Defined locally since it's in the same package.
|
|
|
|
func TestNewClientWithNilOption(t *testing.T) {
|
|
client, err := NewClient("http://localhost:8080", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if client == nil {
|
|
t.Fatal("client is nil")
|
|
}
|
|
}
|
|
|
|
func TestNewHTTPError(t *testing.T) {
|
|
e := newHTTPError("GET", "/v1/models", 200, []byte(`{"ok":true}`))
|
|
if e.Method != "GET" || e.Path != "/v1/models" || e.StatusCode != 200 || e.Body != `{"ok":true}` {
|
|
t.Fatalf("unexpected http error: %+v", e)
|
|
}
|
|
}
|
|
|
|
func TestPerformWithMockServer(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/api/v1/admin/system/version":
|
|
w.Write([]byte(`{"data":{"version":"v1.2.3"}}`))
|
|
case "/api/v1/admin/groups":
|
|
w.Write([]byte(`{"data":{"id":"g1","name":"test-group"}}`))
|
|
case "/api/v1/admin/channels":
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
w.Write([]byte(`{"error":"panic"}`))
|
|
default:
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client, err := NewClient(srv.URL, WithAPIKey("test-key"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
t.Run("GetHostVersion", func(t *testing.T) {
|
|
ver, err := client.GetHostVersion(context.Background())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ver != "v1.2.3" {
|
|
t.Fatalf("version = %q, want %q", ver, "v1.2.3")
|
|
}
|
|
})
|
|
|
|
t.Run("postJSON success", func(t *testing.T) {
|
|
var ref GroupRef
|
|
if err := client.postJSON(context.Background(), "/api/v1/admin/groups", CreateGroupRequest{Name: "test"}, &ref); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ref.ID != "g1" || ref.Name != "test-group" {
|
|
t.Fatalf("got %+v, want {ID:g1 Name:test-group}", ref)
|
|
}
|
|
})
|
|
|
|
t.Run("postJSON error status", func(t *testing.T) {
|
|
var ref GroupRef
|
|
err := client.postJSON(context.Background(), "/api/v1/admin/channels", nil, &ref)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
var httpErr *HTTPError
|
|
if !errors.As(err, &httpErr) {
|
|
t.Fatalf("expected HTTPError, got %T: %v", err, err)
|
|
}
|
|
if httpErr.StatusCode != 500 {
|
|
t.Fatalf("status code = %d, want 500", httpErr.StatusCode)
|
|
}
|
|
})
|
|
|
|
t.Run("getJSON success", func(t *testing.T) {
|
|
var ref GroupRef
|
|
if err := client.getJSON(context.Background(), "/api/v1/admin/groups", &ref); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
})
|
|
|
|
t.Run("getJSON error status", func(t *testing.T) {
|
|
var ref GroupRef
|
|
err := client.getJSON(context.Background(), "/bad/path", &ref)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestCreateGroupWithMock(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var req struct {
|
|
Name string `json:"name"`
|
|
Platform string `json:"platform"`
|
|
RateMultiplier float64 `json:"rate_multiplier"`
|
|
SubscriptionType string `json:"subscription_type"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
t.Fatalf("decode request: %v", err)
|
|
}
|
|
if req.Name != "demo" || req.Platform != "openai" || req.RateMultiplier != 1.0 {
|
|
t.Fatalf("unexpected request: %+v", req)
|
|
}
|
|
w.Write([]byte(`{"data":{"id":"g1","name":"demo"}}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client, err := NewClient(srv.URL, WithAPIKey("k"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ref, err := client.CreateGroup(context.Background(), CreateGroupRequest{Name: "demo", Platform: "openai", RateMultiplier: 1.0})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ref.ID != "g1" || ref.Name != "demo" {
|
|
t.Fatalf("got %+v", ref)
|
|
}
|
|
}
|
|
|
|
func TestCreateChannelWithMock(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var req struct {
|
|
Name string `json:"name"`
|
|
GroupIDs []int64 `json:"group_ids"`
|
|
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
|
ModelPricing []struct {
|
|
Platform string `json:"platform"`
|
|
Models []string `json:"models"`
|
|
BillingMode string `json:"billing_mode"`
|
|
InputPrice *float64 `json:"input_price"`
|
|
OutputPrice *float64 `json:"output_price"`
|
|
CacheWritePrice *float64 `json:"cache_write_price"`
|
|
CacheReadPrice *float64 `json:"cache_read_price"`
|
|
ImageOutputPrice *float64 `json:"image_output_price"`
|
|
PerRequestPrice *float64 `json:"per_request_price"`
|
|
Intervals []any `json:"intervals"`
|
|
} `json:"model_pricing"`
|
|
RestrictModels bool `json:"restrict_models"`
|
|
BillingModelSource string `json:"billing_model_source"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
t.Fatalf("decode request: %v", err)
|
|
}
|
|
if req.Name != "ch" {
|
|
t.Fatalf("name = %q, want ch", req.Name)
|
|
}
|
|
if len(req.GroupIDs) != 1 || req.GroupIDs[0] != 101 {
|
|
t.Fatalf("group_ids = %v, want [101]", req.GroupIDs)
|
|
}
|
|
if req.ModelMapping["openai"]["deepseek-v4-pro"] != "deepseek-v4-pro" {
|
|
t.Fatalf("model_mapping = %+v, want openai/deepseek-v4-pro passthrough", req.ModelMapping)
|
|
}
|
|
if len(req.ModelPricing) != 1 {
|
|
t.Fatalf("model_pricing len = %d, want 1", len(req.ModelPricing))
|
|
}
|
|
if req.ModelPricing[0].Platform != "openai" || req.ModelPricing[0].BillingMode != "token" {
|
|
t.Fatalf("model_pricing[0] = %+v, want openai/token entry", req.ModelPricing[0])
|
|
}
|
|
if len(req.ModelPricing[0].Models) != 1 || req.ModelPricing[0].Models[0] != "deepseek-v4-pro" {
|
|
t.Fatalf("model_pricing[0].models = %v, want [deepseek-v4-pro]", req.ModelPricing[0].Models)
|
|
}
|
|
if !req.RestrictModels {
|
|
t.Fatal("restrict_models = false, want true")
|
|
}
|
|
if req.BillingModelSource != "channel_mapped" {
|
|
t.Fatalf("billing_model_source = %q, want channel_mapped", req.BillingModelSource)
|
|
}
|
|
w.Write([]byte(`{"data":{"id":201,"name":"ch"}}`))
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
ref, err := client.CreateChannel(context.Background(), CreateChannelRequest{
|
|
Name: "ch",
|
|
GroupIDs: []string{"101"},
|
|
ModelMapping: map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"},
|
|
ModelPricing: []ChannelModelPricing{{Platform: "openai", Models: []string{"deepseek-v4-pro"}, BillingMode: "token"}},
|
|
RestrictModels: true,
|
|
BillingModelSource: "channel_mapped",
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ref.ID != "201" {
|
|
t.Fatalf("id = %q, want 201", ref.ID)
|
|
}
|
|
}
|
|
|
|
func TestUpdateChannelWithMock(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPut {
|
|
t.Fatalf("method = %s, want PUT", r.Method)
|
|
}
|
|
if r.URL.Path != "/api/v1/admin/channels/201" {
|
|
t.Fatalf("path = %s, want /api/v1/admin/channels/201", r.URL.Path)
|
|
}
|
|
var req struct {
|
|
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
|
ModelPricing []struct {
|
|
Platform string `json:"platform"`
|
|
Models []string `json:"models"`
|
|
BillingMode string `json:"billing_mode"`
|
|
} `json:"model_pricing"`
|
|
RestrictModels bool `json:"restrict_models"`
|
|
BillingModelSource string `json:"billing_model_source"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
t.Fatalf("decode request: %v", err)
|
|
}
|
|
if req.ModelMapping["openai"]["deepseek-v4-pro"] != "deepseek-v4-pro" {
|
|
t.Fatalf("model_mapping = %+v, want openai/deepseek-v4-pro passthrough", req.ModelMapping)
|
|
}
|
|
if len(req.ModelPricing) != 1 || req.ModelPricing[0].Platform != "openai" || req.ModelPricing[0].BillingMode != "token" {
|
|
t.Fatalf("model_pricing = %+v, want openai/token entry", req.ModelPricing)
|
|
}
|
|
if !req.RestrictModels {
|
|
t.Fatal("restrict_models = false, want true")
|
|
}
|
|
if req.BillingModelSource != "channel_mapped" {
|
|
t.Fatalf("billing_model_source = %q, want channel_mapped", req.BillingModelSource)
|
|
}
|
|
w.Write([]byte(`{"data":{"id":201,"name":"ch"}}`))
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
if err := client.UpdateChannel(context.Background(), "201", CreateChannelRequest{
|
|
Name: "ch",
|
|
GroupIDs: []string{"101"},
|
|
ModelMapping: map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"},
|
|
ModelPricing: []ChannelModelPricing{{Platform: "openai", Models: []string{"deepseek-v4-pro"}, BillingMode: "token"}},
|
|
RestrictModels: true,
|
|
BillingModelSource: "channel_mapped",
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestCreateChannelRequestMarshalJSONDefaultsPricingPlatform(t *testing.T) {
|
|
t.Run("request platform", func(t *testing.T) {
|
|
payload, err := json.Marshal(CreateChannelRequest{
|
|
Name: "ch",
|
|
GroupIDs: []string{"101"},
|
|
Platform: "openai",
|
|
ModelMapping: map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"},
|
|
ModelPricing: []ChannelModelPricing{{Models: []string{"deepseek-v4-pro"}, BillingMode: "token"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Marshal() error = %v", err)
|
|
}
|
|
var got struct {
|
|
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
|
ModelPricing []struct {
|
|
Platform string `json:"platform"`
|
|
} `json:"model_pricing"`
|
|
}
|
|
if err := json.Unmarshal(payload, &got); err != nil {
|
|
t.Fatalf("Unmarshal() error = %v", err)
|
|
}
|
|
if got.ModelMapping["openai"]["deepseek-v4-pro"] != "deepseek-v4-pro" {
|
|
t.Fatalf("model_mapping = %+v, want openai/deepseek-v4-pro passthrough", got.ModelMapping)
|
|
}
|
|
if len(got.ModelPricing) != 1 || got.ModelPricing[0].Platform != "openai" {
|
|
t.Fatalf("model_pricing = %+v, want platform openai", got.ModelPricing)
|
|
}
|
|
})
|
|
|
|
t.Run("openai fallback", func(t *testing.T) {
|
|
payload, err := json.Marshal(CreateChannelRequest{
|
|
Name: "ch",
|
|
GroupIDs: []string{"101"},
|
|
ModelMapping: map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"},
|
|
ModelPricing: []ChannelModelPricing{{Models: []string{"deepseek-v4-pro"}, BillingMode: "token"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Marshal() error = %v", err)
|
|
}
|
|
var got struct {
|
|
ModelPricing []struct {
|
|
Platform string `json:"platform"`
|
|
} `json:"model_pricing"`
|
|
}
|
|
if err := json.Unmarshal(payload, &got); err != nil {
|
|
t.Fatalf("Unmarshal() error = %v", err)
|
|
}
|
|
if len(got.ModelPricing) != 1 || got.ModelPricing[0].Platform != "openai" {
|
|
t.Fatalf("model_pricing = %+v, want platform openai fallback", got.ModelPricing)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestCreatePlanWithMock(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var req struct {
|
|
GroupID int64 `json:"group_id"`
|
|
Name string `json:"name"`
|
|
Price float64 `json:"price"`
|
|
ValidityDays int `json:"validity_days"`
|
|
ValidityUnit string `json:"validity_unit"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
t.Fatalf("decode request: %v", err)
|
|
}
|
|
if req.GroupID != 101 || req.Name != "plan" || req.Price != 19.9 || req.ValidityDays != 30 || req.ValidityUnit != "day" {
|
|
t.Fatalf("unexpected request: %+v", req)
|
|
}
|
|
w.Write([]byte(`{"data":{"id":301,"name":"plan"}}`))
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
ref, err := client.CreatePlan(context.Background(), CreatePlanRequest{GroupID: "101", Name: "plan", Price: 19.9, ValidityDays: 30, ValidityUnit: "day"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ref.ID != "301" {
|
|
t.Fatalf("id = %q, want 301", ref.ID)
|
|
}
|
|
}
|
|
|
|
func TestDeleteWithMock(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
|
|
t.Run("DeleteGroup", func(t *testing.T) {
|
|
if err := client.DeleteGroup(context.Background(), "g1"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
})
|
|
t.Run("DeleteChannel", func(t *testing.T) {
|
|
if err := client.DeleteChannel(context.Background(), "c1"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
})
|
|
t.Run("DeletePlan", func(t *testing.T) {
|
|
if err := client.DeletePlan(context.Background(), "p1"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
})
|
|
t.Run("DeleteAccount", func(t *testing.T) {
|
|
if err := client.DeleteAccount(context.Background(), "a1"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAssignSubscriptionWithMock(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var req struct {
|
|
UserID int64 `json:"user_id"`
|
|
GroupID int64 `json:"group_id"`
|
|
DurationDays int `json:"validity_days"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
t.Fatalf("decode request: %v", err)
|
|
}
|
|
if req.UserID != 501 || req.GroupID != 101 || req.DurationDays != 30 {
|
|
t.Fatalf("unexpected request: %+v", req)
|
|
}
|
|
w.Write([]byte(`{"data":{"id":401}}`))
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
ref, err := client.AssignSubscription(context.Background(), AssignSubscriptionRequest{UserID: "501", GroupID: "101", DurationDays: 30})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ref.ID != "401" {
|
|
t.Fatalf("id = %q", ref.ID)
|
|
}
|
|
}
|
|
|
|
func TestEnsureSubscriptionAccessWithMock(t *testing.T) {
|
|
var calls []string
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
calls = append(calls, r.Method+" "+r.URL.Path)
|
|
switch {
|
|
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"):
|
|
w.Write([]byte(`{"data":{"items":[]}}`))
|
|
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users":
|
|
w.Write([]byte(`{"data":{"id":84,"email":"relay-sub-user-1@sub2api.local"}}`))
|
|
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/users/84":
|
|
w.Write([]byte(`{"data":{"id":84}}`))
|
|
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/84/balance":
|
|
w.Write([]byte(`{"data":{"id":84}}`))
|
|
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign":
|
|
var req struct {
|
|
UserID int64 `json:"user_id"`
|
|
GroupID int64 `json:"group_id"`
|
|
DurationDays int `json:"validity_days"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
t.Fatalf("decode assign subscription request: %v", err)
|
|
}
|
|
if req.UserID != 84 || req.GroupID != 101 || req.DurationDays != 30 {
|
|
t.Fatalf("unexpected assign subscription request: %+v", req)
|
|
}
|
|
w.Write([]byte(`{"data":{"id":401}}`))
|
|
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login":
|
|
w.Write([]byte(`{"data":{"access_token":"user-jwt"}}`))
|
|
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys":
|
|
var req map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
t.Fatalf("decode managed key request: %v", err)
|
|
}
|
|
if _, ok := req["group_id"]; ok {
|
|
t.Fatalf("managed key request unexpectedly carried group_id: %+v", req)
|
|
}
|
|
w.Write([]byte(`{"data":{"id":501,"key":"sk-relay-key","name":"managed-key"}}`))
|
|
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501":
|
|
w.Write([]byte(`{"data":{"api_key":{"id":501}}}`))
|
|
default:
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithBearerToken("admin-token"))
|
|
ref, err := client.EnsureSubscriptionAccess(context.Background(), EnsureSubscriptionAccessRequest{UserSelector: "crm-user-1", GroupID: "101"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ref.UserID != "84" {
|
|
t.Fatalf("user id = %q, want 84", ref.UserID)
|
|
}
|
|
if !strings.HasPrefix(ref.APIKey, "sk-relay-") {
|
|
t.Fatalf("api key = %q, want managed sk-relay-* key", ref.APIKey)
|
|
}
|
|
if len(calls) < 7 {
|
|
t.Fatalf("calls = %v, want managed subscription setup sequence", calls)
|
|
}
|
|
}
|
|
|
|
func TestCheckGatewayAccessWithMock(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if got := r.Header.Get("Authorization"); got != "Bearer gk" {
|
|
t.Fatalf("Authorization = %q, want %q", got, "Bearer gk")
|
|
}
|
|
if got := r.Header.Get("x-api-key"); got != "" {
|
|
t.Fatalf("x-api-key = %q, want empty", got)
|
|
}
|
|
w.Write([]byte(`{"data":[{"id":"gpt-4"},{"id":"claude-3"}]}`))
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
result, err := client.CheckGatewayAccess(context.Background(), GatewayAccessCheckRequest{APIKey: "gk", ExpectedModel: "gpt-4"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !result.OK {
|
|
t.Fatal("expected OK=true")
|
|
}
|
|
if !result.HasExpectedModel {
|
|
t.Fatal("expected HasExpectedModel=true")
|
|
}
|
|
}
|
|
|
|
func TestCheckGatewayCompletionWithMock(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/chat/completions" {
|
|
t.Fatalf("path = %q, want /v1/chat/completions", r.URL.Path)
|
|
}
|
|
if got := r.Header.Get("Authorization"); got != "Bearer gk" {
|
|
t.Fatalf("Authorization = %q, want %q", got, "Bearer gk")
|
|
}
|
|
if got := r.Header.Get("x-api-key"); got != "" {
|
|
t.Fatalf("x-api-key = %q, want empty", got)
|
|
}
|
|
var payload struct {
|
|
Model string `json:"model"`
|
|
MaxTokens int `json:"max_tokens"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
|
t.Fatalf("decode request: %v", err)
|
|
}
|
|
if payload.Model != "gpt-4" {
|
|
t.Fatalf("model = %q, want gpt-4", payload.Model)
|
|
}
|
|
if payload.MaxTokens != 8 {
|
|
t.Fatalf("max_tokens = %d, want 8", payload.MaxTokens)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write([]byte(`{"choices":[{"message":{"content":"pong"}}]}`))
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
result, err := client.CheckGatewayCompletion(context.Background(), GatewayCompletionCheckRequest{APIKey: "gk", Model: "gpt-4"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !result.OK {
|
|
t.Fatal("expected completion OK=true")
|
|
}
|
|
if result.StatusCode != 200 {
|
|
t.Fatalf("status = %d, want 200", result.StatusCode)
|
|
}
|
|
if result.ContentType != "application/json" {
|
|
t.Fatalf("content type = %q, want application/json", result.ContentType)
|
|
}
|
|
}
|
|
|
|
func TestDisableOpenAIResponsesAPIWithMock(t *testing.T) {
|
|
var calls []string
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
calls = append(calls, r.Method+" "+r.URL.Path)
|
|
if r.Method != http.MethodPut {
|
|
t.Fatalf("method = %q, want PUT", r.Method)
|
|
}
|
|
var payload struct {
|
|
Extra map[string]any `json:"extra"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
|
t.Fatalf("decode request: %v", err)
|
|
}
|
|
if got, ok := payload.Extra["openai_responses_supported"].(bool); !ok || got {
|
|
t.Fatalf("openai_responses_supported = %+v, want false", payload.Extra["openai_responses_supported"])
|
|
}
|
|
w.Write([]byte(`{"data":{"id":1}}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
if err := client.DisableOpenAIResponsesAPI(context.Background(), []string{"101", "101", " ", "102"}); err != nil {
|
|
t.Fatalf("DisableOpenAIResponsesAPI() error = %v", err)
|
|
}
|
|
if len(calls) != 2 {
|
|
t.Fatalf("calls = %v, want 2 unique account updates", calls)
|
|
}
|
|
if calls[0] != "PUT /api/v1/admin/accounts/101" {
|
|
t.Fatalf("first call = %q, want PUT /api/v1/admin/accounts/101", calls[0])
|
|
}
|
|
if calls[1] != "PUT /api/v1/admin/accounts/102" {
|
|
t.Fatalf("second call = %q, want PUT /api/v1/admin/accounts/102", calls[1])
|
|
}
|
|
}
|
|
|
|
func TestBatchCreateAccountsWithMock(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var req struct {
|
|
Accounts []struct {
|
|
Name string `json:"name"`
|
|
Platform string `json:"platform"`
|
|
Type string `json:"type"`
|
|
Credentials map[string]any `json:"credentials"`
|
|
GroupIDs []int64 `json:"group_ids"`
|
|
} `json:"accounts"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
t.Fatalf("decode request: %v", err)
|
|
}
|
|
if len(req.Accounts) != 1 {
|
|
t.Fatalf("accounts len = %d, want 1", len(req.Accounts))
|
|
}
|
|
acct := req.Accounts[0]
|
|
if acct.Name != "acct1" || acct.Platform != "openai" || acct.Type != "apikey" {
|
|
t.Fatalf("unexpected account metadata: %+v", acct)
|
|
}
|
|
if len(acct.GroupIDs) != 1 || acct.GroupIDs[0] != 101 {
|
|
t.Fatalf("group_ids = %v, want [101]", acct.GroupIDs)
|
|
}
|
|
rawMapping, ok := acct.Credentials["model_mapping"].(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("credentials = %+v, want model_mapping map", acct.Credentials)
|
|
}
|
|
if got, _ := rawMapping["deepseek-v4-pro"].(string); got != "deepseek-v4-pro" {
|
|
t.Fatalf("model_mapping = %+v, want deepseek-v4-pro passthrough", rawMapping)
|
|
}
|
|
w.Write([]byte(`{"data":[{"id":601,"name":"acct1"}]}`))
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
refs, err := client.BatchCreateAccounts(context.Background(), BatchCreateAccountsRequest{
|
|
Accounts: []CreateAccountRequest{{Name: "acct1", Platform: "openai", Type: "apikey", GroupIDs: []string{"101"}, Credentials: map[string]any{"api_key": "sk-test", "base_url": "https://api.example.com", "model_mapping": map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"}}}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(refs) != 1 || refs[0].ID != "601" {
|
|
t.Fatalf("got %+v", refs)
|
|
}
|
|
}
|
|
|
|
func TestProbeCapabilitiesWithMock(t *testing.T) {
|
|
callCount := 0
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
callCount++
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(`{"data":[]}`))
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
caps, err := client.ProbeCapabilities(context.Background())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !caps.Groups || !caps.Channels || !caps.Plans || !caps.Accounts || !caps.AccountTest || !caps.AccountModels || !caps.Subscriptions {
|
|
t.Fatalf("all capabilities should be true, got %+v", caps)
|
|
}
|
|
if callCount != 7 {
|
|
t.Fatalf("callCount = %d, want 7", callCount)
|
|
}
|
|
}
|
|
|
|
func TestProbeCapabilitiesDoesNotTreat404AsSupportForAccountOrSubscriptionRoutes(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
switch r.URL.Path {
|
|
case "/api/v1/admin/groups", "/api/v1/admin/channels", "/api/v1/admin/payment/plans", "/api/v1/admin/accounts":
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(`{"data":[]}`))
|
|
case "/api/v1/admin/accounts/__probe__/test", "/api/v1/admin/accounts/__probe__/models", "/api/v1/admin/subscriptions/assign":
|
|
w.WriteHeader(http.StatusNotFound)
|
|
_, _ = w.Write([]byte(`{"error":"not found"}`))
|
|
default:
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
caps, err := client.ProbeCapabilities(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("ProbeCapabilities() error = %v", err)
|
|
}
|
|
if caps.AccountTest {
|
|
t.Fatal("AccountTest = true, want false on 404 probe route")
|
|
}
|
|
if caps.AccountModels {
|
|
t.Fatal("AccountModels = true, want false on 404 probe route")
|
|
}
|
|
if caps.Subscriptions {
|
|
t.Fatal("Subscriptions = true, want false on 404 probe route")
|
|
}
|
|
}
|
|
|
|
func TestListManagedResourcesWithMock(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte(`{"data":{"items":[
|
|
{"id":"r1","name":"resource-1"}
|
|
]}}`))
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
snapshot, err := client.ListManagedResources(context.Background(), ListManagedResourcesRequest{})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(snapshot.Groups) != 1 {
|
|
t.Fatalf("expected 1 group, got %d", len(snapshot.Groups))
|
|
}
|
|
}
|
|
|
|
func TestListManagedResourcesLoadsAllAccountPages(t *testing.T) {
|
|
accountPages := 0
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/api/v1/admin/groups", "/api/v1/admin/channels":
|
|
_, _ = w.Write([]byte(`{"data":{"items":[{"id":"r1","name":"resource-1"}],"total":1,"page":1,"page_size":20,"pages":1}}`))
|
|
case "/api/v1/admin/payment/plans":
|
|
_, _ = w.Write([]byte(`{"data":[{"id":"plan_1","name":"plan-1"}]}`))
|
|
case "/api/v1/admin/accounts":
|
|
accountPages++
|
|
page := r.URL.Query().Get("page")
|
|
if page == "" {
|
|
page = "1"
|
|
}
|
|
if got := r.URL.Query().Get("page_size"); got != "100" {
|
|
t.Fatalf("page_size = %q, want 100", got)
|
|
}
|
|
switch page {
|
|
case "1":
|
|
_, _ = w.Write([]byte(`{"data":{"items":[{"id":"account_1","name":"deepseek-01"}],"total":2,"page":1,"page_size":100,"pages":2}}`))
|
|
case "2":
|
|
_, _ = w.Write([]byte(`{"data":{"items":[{"id":"account_2","name":"deepseek-02"}],"total":2,"page":2,"page_size":100,"pages":2}}`))
|
|
default:
|
|
t.Fatalf("unexpected accounts page %q", page)
|
|
}
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
snapshot, err := client.ListManagedResources(context.Background(), ListManagedResourcesRequest{AccountNamePrefix: "deepseek-"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if accountPages != 2 {
|
|
t.Fatalf("account pages fetched = %d, want 2", accountPages)
|
|
}
|
|
if len(snapshot.Accounts) != 2 || snapshot.Accounts[0].ID != "account_1" || snapshot.Accounts[1].ID != "account_2" {
|
|
t.Fatalf("Accounts = %+v, want both paged accounts", snapshot.Accounts)
|
|
}
|
|
}
|
|
|
|
func TestTestAccountWithMock(t *testing.T) {
|
|
var requestBody map[string]any
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
|
t.Fatalf("decode request body: %v", err)
|
|
}
|
|
w.Write([]byte("data: {\"status\":\"passed\",\"ok\":true}\n"))
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
result, err := client.TestAccount(context.Background(), "a1", "MiniMax-M2.7-highspeed")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !result.OK {
|
|
t.Fatal("expected OK=true")
|
|
}
|
|
if got := requestBody["model_id"]; got != "MiniMax-M2.7-highspeed" {
|
|
t.Fatalf("model_id = %#v, want MiniMax-M2.7-highspeed", got)
|
|
}
|
|
}
|
|
|
|
func TestTestAccountWithMockSSEError(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Write([]byte("data: {\"type\":\"test_start\",\"model\":\"MiniMax-M2.7-highspeed\"}\n\n"))
|
|
w.Write([]byte("data: {\"type\":\"error\",\"error\":\"账号本身可正常使用,但当前测试接口仅支持 Responses API 路径。请直接通过实际 API 调用验证。\"}\n\n"))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
result, err := client.TestAccount(context.Background(), "a1", "")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result.OK {
|
|
t.Fatal("expected OK=false for SSE error event")
|
|
}
|
|
if result.Status != "failed" {
|
|
t.Fatalf("Status = %q, want failed", result.Status)
|
|
}
|
|
if !strings.Contains(result.Message, "测试接口仅支持 Responses API 路径") {
|
|
t.Fatalf("Message = %q, want propagated SSE error message", result.Message)
|
|
}
|
|
}
|
|
|
|
func TestGetAccountModelsWithMock(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte(`{"data":[{"id":"m1","display_name":"M1","type":"chat"}]}`))
|
|
}))
|
|
defer srv.Close()
|
|
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
|
models, err := client.GetAccountModels(context.Background(), "a1")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(models) != 1 || models[0].ID != "m1" {
|
|
t.Fatalf("got %+v", models)
|
|
}
|
|
}
|