115 lines
3.2 KiB
Go
115 lines
3.2 KiB
Go
package sub2api
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"strings"
|
|
)
|
|
|
|
// GatewayAccessCheckRequest carries the inputs for Client.CheckGatewayAccess.
type GatewayAccessCheckRequest struct {
	// APIKey is sent as the bearer token for the probe (CheckGatewayAccess
	// trims surrounding whitespace). ExpectedModel is the model ID whose
	// presence in the /v1/models listing is reported via HasExpectedModel.
	APIKey, ExpectedModel string
}
|
|
|
|
// GatewayAccessResult aggregates the outcome of gateway probes. The field
// order and JSON tags determine the serialized form, so keep them stable.
type GatewayAccessResult struct {
	// Outcome of the GET /v1/models access check (set by CheckGatewayAccess).
	OK               bool     `json:"ok"`
	StatusCode       int      `json:"status_code"`
	Models           []string `json:"models"`
	HasExpectedModel bool     `json:"has_expected_model"`

	// Outcome of the chat-completion probe. These are not set by
	// CheckGatewayAccess; presumably the caller copies them from a
	// GatewayCompletionResult — confirm.
	CompletionOK     bool   `json:"completion_ok"`
	CompletionStatus int    `json:"completion_status"`
	CompletionType   string `json:"completion_content_type,omitempty"`
	CompletionBody   string `json:"completion_body_preview,omitempty"`

	// Presumably the API key used for probing and where it came from (not
	// set in this file — confirm). Tagged `json:"-"` so neither value is
	// ever serialized.
	EffectiveProbeAPIKey    string `json:"-"`
	EffectiveProbeKeySource string `json:"-"`
}
|
|
|
|
func (c *Client) CheckGatewayAccess(ctx context.Context, req GatewayAccessCheckRequest) (GatewayAccessResult, error) {
|
|
gatewayClient := *c
|
|
gatewayClient.apiKey = ""
|
|
gatewayClient.bearerToken = strings.TrimSpace(req.APIKey)
|
|
|
|
statusCode, _, body, err := gatewayClient.perform(ctx, http.MethodGet, "/v1/models", nil)
|
|
if err != nil {
|
|
return GatewayAccessResult{}, err
|
|
}
|
|
|
|
result := GatewayAccessResult{StatusCode: statusCode, OK: statusCode >= http.StatusOK && statusCode < http.StatusMultipleChoices}
|
|
if !result.OK {
|
|
return result, nil
|
|
}
|
|
result.Models = decodeGatewayModelIDs(body)
|
|
for _, modelID := range result.Models {
|
|
if modelID == strings.TrimSpace(req.ExpectedModel) {
|
|
result.HasExpectedModel = true
|
|
break
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (c *Client) CheckGatewayCompletion(ctx context.Context, req GatewayCompletionCheckRequest) (GatewayCompletionResult, error) {
|
|
gatewayClient := *c
|
|
gatewayClient.apiKey = ""
|
|
gatewayClient.bearerToken = strings.TrimSpace(req.APIKey)
|
|
|
|
model := strings.TrimSpace(req.Model)
|
|
if model == "" {
|
|
return GatewayCompletionResult{}, nil
|
|
}
|
|
prompt := strings.TrimSpace(req.Prompt)
|
|
if prompt == "" {
|
|
prompt = "ping"
|
|
}
|
|
maxTokens := req.MaxTokens
|
|
if maxTokens <= 0 {
|
|
maxTokens = 8
|
|
}
|
|
payload := map[string]any{
|
|
"model": model,
|
|
"messages": []map[string]string{
|
|
{"role": "user", "content": prompt},
|
|
},
|
|
"max_tokens": maxTokens,
|
|
"temperature": 0,
|
|
}
|
|
|
|
statusCode, headers, body, err := gatewayClient.perform(ctx, http.MethodPost, "/v1/chat/completions", payload)
|
|
if err != nil {
|
|
return GatewayCompletionResult{}, err
|
|
}
|
|
return GatewayCompletionResult{
|
|
OK: statusCode >= http.StatusOK && statusCode < http.StatusMultipleChoices,
|
|
StatusCode: statusCode,
|
|
ContentType: strings.TrimSpace(headers.Get("Content-Type")),
|
|
BodyPreview: previewGatewayBody(body, 400),
|
|
}, nil
|
|
}
|
|
|
|
// decodeGatewayModelIDs extracts model identifiers from a listing shaped like
// {"data":[{"id":"..."}, ...]}.
//
// IDs are whitespace-trimmed and blank ones are dropped. The function returns
// nil when the body fails to unmarshal into that shape or when "data" is
// absent or empty. When "data" has entries but every ID is blank, it returns
// an empty non-nil slice.
func decodeGatewayModelIDs(body []byte) []string {
	var listing struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if json.Unmarshal(body, &listing) != nil || len(listing.Data) == 0 {
		return nil
	}

	ids := make([]string, 0, len(listing.Data))
	for i := range listing.Data {
		id := strings.TrimSpace(listing.Data[i].ID)
		if id == "" {
			continue
		}
		ids = append(ids, id)
	}
	return ids
}
|
|
|
|
// previewGatewayBody returns body with surrounding whitespace trimmed,
// truncated to at most limit bytes. A non-positive limit disables truncation.
//
// Truncation backs off to the last rune boundary at or before limit, so a
// multi-byte UTF-8 character is never split. A mid-rune cut would leave
// invalid UTF-8, which encoding/json replaces with U+FFFD if the preview is
// later JSON-encoded. As a result the preview may be up to three bytes
// shorter than limit.
func previewGatewayBody(body []byte, limit int) string {
	trimmed := strings.TrimSpace(string(body))
	if limit <= 0 || len(trimmed) <= limit {
		return trimmed
	}

	// Ranging over a string yields the starting byte offset of each rune
	// (an invalid byte counts as a one-byte rune). The last such offset that
	// is <= limit is therefore a safe place to cut.
	cut := 0
	for i := range trimmed {
		if i > limit {
			break
		}
		cut = i
	}
	return trimmed[:cut]
}
|