Files
2026-05-23 17:06:52 +08:00

115 lines
3.2 KiB
Go

package sub2api
import (
"context"
"encoding/json"
"net/http"
"strings"
)
// GatewayAccessCheckRequest holds the inputs for CheckGatewayAccess.
type GatewayAccessCheckRequest struct {
	// APIKey is whitespace-trimmed and sent as the client's bearer token for
	// the probe (the client's own apiKey is cleared for that request).
	APIKey string
	// ExpectedModel is whitespace-trimmed and compared for exact equality
	// against the model IDs returned by GET /v1/models.
	ExpectedModel string
}
// GatewayAccessResult reports the outcome of a gateway access probe.
//
// CheckGatewayAccess fills OK, StatusCode, Models and HasExpectedModel. The
// Completion* and EffectiveProbe* fields are not assigned anywhere in the
// code visible alongside this type; presumably a caller fills them from a
// separate chat-completion probe. NOTE(review): confirm before relying on them.
type GatewayAccessResult struct {
	// OK is true when the /v1/models probe returned a 2xx status code.
	OK bool `json:"ok"`
	// StatusCode is the HTTP status code returned by the /v1/models probe.
	StatusCode int `json:"status_code"`
	// Models holds the trimmed, non-empty model IDs decoded from the
	// response's "data" array. It is nil when the probe was not 2xx, the body
	// did not decode, or "data" was empty.
	Models []string `json:"models"`
	// HasExpectedModel is true when one of Models equals the trimmed
	// ExpectedModel from the request.
	HasExpectedModel bool `json:"has_expected_model"`
	// CompletionOK presumably mirrors the OK flag of a completion probe (TODO confirm).
	CompletionOK bool `json:"completion_ok"`
	// CompletionStatus presumably holds a completion probe's HTTP status (TODO confirm).
	CompletionStatus int `json:"completion_status"`
	// CompletionType presumably holds a completion probe's Content-Type;
	// omitted from JSON when empty.
	CompletionType string `json:"completion_content_type,omitempty"`
	// CompletionBody presumably holds a completion probe's body preview;
	// omitted from JSON when empty.
	CompletionBody string `json:"completion_body_preview,omitempty"`
	// EffectiveProbeAPIKey is tagged `json:"-"`, so encoding/json never
	// serializes it. Presumably it holds the key actually used for probing,
	// which must not leak into JSON output (confirm against the caller).
	EffectiveProbeAPIKey string `json:"-"`
	// EffectiveProbeKeySource is also excluded from JSON; presumably it
	// describes where EffectiveProbeAPIKey came from (TODO confirm).
	EffectiveProbeKeySource string `json:"-"`
}
// CheckGatewayAccess probes GET /v1/models with req.APIKey and reports
// whether the call succeeded and whether req.ExpectedModel is listed.
//
// The probe runs on a shallow copy of c whose apiKey is cleared and whose
// bearerToken is the trimmed req.APIKey, so c itself is left untouched.
// (perform is defined elsewhere; presumably it sends bearerToken as an
// Authorization: Bearer header — confirm.)
//
// A transport error from perform is returned as-is, with a zero result.
// A non-2xx response is not an error. It yields OK=false, the status code,
// and no models.
func (c *Client) CheckGatewayAccess(ctx context.Context, req GatewayAccessCheckRequest) (GatewayAccessResult, error) {
	probe := *c
	probe.apiKey = ""
	probe.bearerToken = strings.TrimSpace(req.APIKey)

	status, _, body, err := probe.perform(ctx, http.MethodGet, "/v1/models", nil)
	if err != nil {
		return GatewayAccessResult{}, err
	}

	if status < http.StatusOK || status >= http.StatusMultipleChoices {
		return GatewayAccessResult{StatusCode: status}, nil
	}

	// Trim once, outside the scan; decoded IDs are already trimmed and
	// non-empty, so an empty expectation can never match.
	wanted := strings.TrimSpace(req.ExpectedModel)
	models := decodeGatewayModelIDs(body)
	found := false
	for _, id := range models {
		if id == wanted {
			found = true
			break
		}
	}

	return GatewayAccessResult{
		OK:               true,
		StatusCode:       status,
		Models:           models,
		HasExpectedModel: found,
	}, nil
}
// CheckGatewayCompletion sends one small POST /v1/chat/completions request
// with req.APIKey and reports the status, Content-Type and a body preview.
//
// Like the access probe, it runs on a shallow copy of c with apiKey cleared
// and bearerToken set to the trimmed req.APIKey.
//
// Defaults: an empty (after trimming) Prompt becomes "ping", and a
// non-positive MaxTokens becomes 8. Temperature is fixed at 0.
//
// If the trimmed Model is empty, no request is made and a zero-value result
// is returned with a nil error. A transport error from perform is returned
// with a zero result. A non-2xx status is reported via OK=false, not as an
// error. The body preview is capped at 400 bytes.
func (c *Client) CheckGatewayCompletion(ctx context.Context, req GatewayCompletionCheckRequest) (GatewayCompletionResult, error) {
	probe := *c
	probe.apiKey = ""
	probe.bearerToken = strings.TrimSpace(req.APIKey)

	model := strings.TrimSpace(req.Model)
	if len(model) == 0 {
		// Nothing to probe without a model name.
		return GatewayCompletionResult{}, nil
	}

	prompt := "ping"
	if p := strings.TrimSpace(req.Prompt); p != "" {
		prompt = p
	}

	// Keep req.MaxTokens' own type; only substitute the default when unset.
	maxTokens := req.MaxTokens
	if maxTokens <= 0 {
		maxTokens = 8
	}

	messages := []map[string]string{{"role": "user", "content": prompt}}
	payload := map[string]any{
		"model":       model,
		"messages":    messages,
		"max_tokens":  maxTokens,
		"temperature": 0,
	}

	status, headers, body, err := probe.perform(ctx, http.MethodPost, "/v1/chat/completions", payload)
	if err != nil {
		return GatewayCompletionResult{}, err
	}

	var result GatewayCompletionResult
	result.OK = status >= http.StatusOK && status < http.StatusMultipleChoices
	result.StatusCode = status
	result.ContentType = strings.TrimSpace(headers.Get("Content-Type"))
	result.BodyPreview = previewGatewayBody(body, 400)
	return result, nil
}
// decodeGatewayModelIDs extracts model IDs from an OpenAI-style list body of
// the form {"data":[{"id":"..."}, ...]}.
//
// It returns nil if the body fails to unmarshal (including type mismatches)
// or if "data" is missing or empty. Otherwise it returns the whitespace-trimmed
// IDs in order, skipping blank ones. The result is then non-nil, possibly
// empty, when every ID was blank.
func decodeGatewayModelIDs(body []byte) []string {
	var listing struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if json.Unmarshal(body, &listing) != nil || len(listing.Data) == 0 {
		return nil
	}

	ids := make([]string, 0, len(listing.Data))
	for i := range listing.Data {
		id := strings.TrimSpace(listing.Data[i].ID)
		if id == "" {
			continue
		}
		ids = append(ids, id)
	}
	return ids
}
// previewGatewayBody returns body with surrounding whitespace trimmed and
// truncated to at most limit bytes. A limit <= 0 disables truncation.
//
// The cut never splits a multi-byte UTF-8 character: it is moved back to the
// last rune boundary at or before limit. The preview may therefore be up to
// utf8.UTFMax-1 bytes shorter than limit. Cutting mid-rune would leave a
// dangling partial sequence that encoding/json would render as U+FFFD.
func previewGatewayBody(body []byte, limit int) string {
	trimmed := strings.TrimSpace(string(body))
	if limit <= 0 || len(trimmed) <= limit {
		return trimmed
	}
	// Ranging over a string yields the byte offset of each rune start (an
	// invalid byte counts as a one-byte rune). cut therefore ends at the
	// last rune start <= limit, which is limit itself on a boundary.
	cut := 0
	for i := range trimmed {
		if i > limit {
			break
		}
		cut = i
	}
	return trimmed[:cut]
}