first commit

This commit is contained in:
2025-11-21 16:32:35 +08:00
parent a54424afba
commit ce361482f4
26 changed files with 2445 additions and 0 deletions

View File

@@ -0,0 +1,315 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"sync"
"time"
"linkmaster-node/internal/config"
"linkmaster-node/internal/continuous"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
var continuousTasks = make(map[string]*ContinuousTask)
var taskMutex sync.RWMutex
var backendURL string
var logger *zap.Logger
func InitContinuousHandler(cfg *config.Config) {
backendURL = cfg.Backend.URL
logger, _ = zap.NewProduction()
}
type ContinuousTask struct {
TaskID string
Type string
Target string
Interval time.Duration
MaxDuration time.Duration
StartTime time.Time
LastRequest time.Time
StopCh chan struct{}
IsRunning bool
pingTask *continuous.PingTask
tcpingTask *continuous.TCPingTask
}
func HandleContinuousStart(c *gin.Context) {
var req struct {
Type string `json:"type" binding:"required"`
Target string `json:"target" binding:"required"`
Interval int `json:"interval"` // 秒
MaxDuration int `json:"max_duration"` // 分钟
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 生成任务ID
taskID := generateTaskID()
// 设置默认值
interval := 10 * time.Second
if req.Interval > 0 {
interval = time.Duration(req.Interval) * time.Second
}
maxDuration := 60 * time.Minute
if req.MaxDuration > 0 {
maxDuration = time.Duration(req.MaxDuration) * time.Minute
}
// 创建任务
task := &ContinuousTask{
TaskID: taskID,
Type: req.Type,
Target: req.Target,
Interval: interval,
MaxDuration: maxDuration,
StartTime: time.Now(),
LastRequest: time.Now(),
StopCh: make(chan struct{}),
IsRunning: true,
}
// 根据类型创建对应的任务
if req.Type == "ping" {
pingTask := continuous.NewPingTask(taskID, req.Target, interval, maxDuration)
task.pingTask = pingTask
} else if req.Type == "tcping" {
tcpingTask, err := continuous.NewTCPingTask(taskID, req.Target, interval, maxDuration)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
task.tcpingTask = tcpingTask
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的持续测试类型"})
return
}
taskMutex.Lock()
continuousTasks[taskID] = task
taskMutex.Unlock()
// 启动持续测试goroutine
ctx := context.Background()
if task.pingTask != nil {
go task.pingTask.Start(ctx, func(result map[string]interface{}) {
pushResultToBackend(taskID, result)
})
} else if task.tcpingTask != nil {
go task.tcpingTask.Start(ctx, func(result map[string]interface{}) {
pushResultToBackend(taskID, result)
})
}
c.JSON(http.StatusOK, gin.H{
"task_id": taskID,
})
}
func HandleContinuousStop(c *gin.Context) {
var req struct {
TaskID string `json:"task_id" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
taskMutex.Lock()
task, exists := continuousTasks[req.TaskID]
if exists {
task.IsRunning = false
if task.pingTask != nil {
task.pingTask.Stop()
}
if task.tcpingTask != nil {
task.tcpingTask.Stop()
}
close(task.StopCh)
delete(continuousTasks, req.TaskID)
}
taskMutex.Unlock()
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "任务不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "任务已停止"})
}
func HandleContinuousStatus(c *gin.Context) {
taskID := c.Query("task_id")
if taskID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "task_id参数缺失"})
return
}
taskMutex.RLock()
task, exists := continuousTasks[taskID]
if exists {
// 更新LastRequest时间
task.LastRequest = time.Now()
if task.pingTask != nil {
task.pingTask.UpdateLastRequest()
}
if task.tcpingTask != nil {
task.tcpingTask.UpdateLastRequest()
}
}
taskMutex.RUnlock()
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "任务不存在"})
return
}
c.JSON(http.StatusOK, gin.H{
"task_id": task.TaskID,
"is_running": task.IsRunning,
"start_time": task.StartTime,
"last_request": task.LastRequest,
})
}
func pushResultToBackend(taskID string, result map[string]interface{}) {
// 推送结果到后端
url := fmt.Sprintf("%s/api/public/node/continuous/result", backendURL)
// 获取本机IP
nodeIP := getLocalIP()
data := map[string]interface{}{
"task_id": taskID,
"node_ip": nodeIP,
"result": result,
}
jsonData, err := json.Marshal(data)
if err != nil {
logger.Error("序列化结果失败", zap.Error(err))
return
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
logger.Error("创建请求失败", zap.Error(err))
return
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
if err != nil {
logger.Warn("推送结果失败", zap.Error(err))
// 如果推送失败,停止任务
taskMutex.Lock()
if task, exists := continuousTasks[taskID]; exists {
task.IsRunning = false
if task.pingTask != nil {
task.pingTask.Stop()
}
if task.tcpingTask != nil {
task.tcpingTask.Stop()
}
delete(continuousTasks, taskID)
}
taskMutex.Unlock()
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.Warn("推送结果失败", zap.Int("status", resp.StatusCode))
// 如果推送失败,停止任务
taskMutex.Lock()
if task, exists := continuousTasks[taskID]; exists {
task.IsRunning = false
if task.pingTask != nil {
task.pingTask.Stop()
}
if task.tcpingTask != nil {
task.tcpingTask.Stop()
}
delete(continuousTasks, taskID)
}
taskMutex.Unlock()
}
}
func getLocalIP() string {
// 简化实现返回第一个非回环IP
// 实际应该获取外网IP
addrs, err := net.InterfaceAddrs()
if err != nil {
return "127.0.0.1"
}
for _, addr := range addrs {
if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
return ipNet.IP.String()
}
}
}
return "127.0.0.1"
}
func generateTaskID() string {
return fmt.Sprintf("task_%d", time.Now().UnixNano())
}
// 定期清理超时任务
func StartTaskCleanup() {
ticker := time.NewTicker(1 * time.Minute)
go func() {
for range ticker.C {
now := time.Now()
taskMutex.Lock()
for taskID, task := range continuousTasks {
// 检查最大运行时长
if now.Sub(task.StartTime) > task.MaxDuration {
logger.Info("任务达到最大运行时长,自动停止", zap.String("task_id", taskID))
task.IsRunning = false
if task.pingTask != nil {
task.pingTask.Stop()
}
if task.tcpingTask != nil {
task.tcpingTask.Stop()
}
delete(continuousTasks, taskID)
continue
}
// 检查无客户端连接30分钟无请求
if now.Sub(task.LastRequest) > 30*time.Minute {
logger.Info("任务无客户端连接,自动停止", zap.String("task_id", taskID))
task.IsRunning = false
if task.pingTask != nil {
task.pingTask.Stop()
}
if task.tcpingTask != nil {
task.tcpingTask.Stop()
}
delete(continuousTasks, taskID)
}
}
taskMutex.Unlock()
}
}()
}

45
internal/handler/dns.go Normal file
View File

@@ -0,0 +1,45 @@
package handler
import (
"net"
"time"
"github.com/gin-gonic/gin"
)
func handleDns(c *gin.Context, url string, params map[string]interface{}) {
// 执行DNS查询
start := time.Now()
ips, err := net.LookupIP(url)
lookupTime := time.Since(start).Milliseconds()
if err != nil {
c.JSON(200, gin.H{
"type": "ceDns",
"url": url,
"error": err.Error(),
})
return
}
// 格式化IP列表
ipList := make([]map[string]interface{}, 0)
for _, ip := range ips {
ipType := "A"
if ip.To4() == nil {
ipType = "AAAA"
}
ipList = append(ipList, map[string]interface{}{
"type": ipType,
"ip": ip.String(),
})
}
c.JSON(200, gin.H{
"type": "ceDns",
"url": url,
"ips": ipList,
"lookup_time": lookupTime,
})
}

View File

@@ -0,0 +1,84 @@
package handler
import (
"net"
"os/exec"
"sync"
"github.com/gin-gonic/gin"
)
func handleFindPing(c *gin.Context, url string, params map[string]interface{}) {
// url应该是CIDR格式如 8.8.8.0/24
cidr := url
if cidrParam, ok := params["cidr"].(string); ok && cidrParam != "" {
cidr = cidrParam
}
// 解析CIDR
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
c.JSON(200, gin.H{
"type": "ceFindPing",
"error": "无效的CIDR格式",
})
return
}
// 生成IP列表
var ipList []string
for ip := ipNet.IP.Mask(ipNet.Mask); ipNet.Contains(ip); incIP(ip) {
ipList = append(ipList, ip.String())
}
// 移除网络地址和广播地址
if len(ipList) > 2 {
ipList = ipList[1 : len(ipList)-1]
}
// 并发ping测试
var wg sync.WaitGroup
var mu sync.Mutex
aliveIPs := make([]string, 0)
// 限制并发数
semaphore := make(chan struct{}, 50)
for _, ip := range ipList {
wg.Add(1)
semaphore <- struct{}{}
go func(ipAddr string) {
defer wg.Done()
defer func() { <-semaphore }()
// 执行ping只ping一次快速检测
cmd := exec.Command("ping", "-c", "1", "-W", "1", ipAddr)
err := cmd.Run()
if err == nil {
mu.Lock()
aliveIPs = append(aliveIPs, ipAddr)
mu.Unlock()
}
}(ip)
}
wg.Wait()
c.JSON(200, gin.H{
"type": "ceFindPing",
"cidr": cidr,
"alive_ips": aliveIPs,
"alive_count": len(aliveIPs),
"total_ips": len(ipList),
})
}
func incIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}

32
internal/handler/get.go Normal file
View File

@@ -0,0 +1,32 @@
package handler
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
)
func handleGet(c *gin.Context, url string, params map[string]interface{}) {
// TODO: 实现HTTP GET测试
// 这里先返回一个简单的响应
c.JSON(http.StatusOK, gin.H{
"type": "ceGet",
"url": url,
"statuscode": 200,
"totaltime": time.Since(time.Now()).Milliseconds(),
"response": "OK",
})
}
func handlePost(c *gin.Context, url string, params map[string]interface{}) {
// TODO: 实现HTTP POST测试
c.JSON(http.StatusOK, gin.H{
"type": "cePost",
"url": url,
"statuscode": 200,
"totaltime": time.Since(time.Now()).Milliseconds(),
"response": "OK",
})
}

85
internal/handler/ping.go Normal file
View File

@@ -0,0 +1,85 @@
package handler
import (
"net"
"os/exec"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
func handlePing(c *gin.Context, url string, params map[string]interface{}) {
// 执行ping命令
cmd := exec.Command("ping", "-c", "4", url)
output, err := cmd.CombinedOutput()
if err != nil {
c.JSON(200, gin.H{
"type": "cePing",
"url": url,
"error": err.Error(),
})
return
}
// 解析ping输出
result := parsePingOutput(string(output), url)
c.JSON(200, result)
}
func parsePingOutput(output, url string) map[string]interface{} {
result := map[string]interface{}{
"type": "cePing",
"url": url,
"ip": "",
}
// 解析IP地址
lines := strings.Split(output, "\n")
for _, line := range lines {
if strings.Contains(line, "PING") {
// 提取IP地址
parts := strings.Fields(line)
for _, part := range parts {
if net.ParseIP(part) != nil {
result["ip"] = part
break
}
}
}
if strings.Contains(line, "packets transmitted") {
// 解析丢包率
parts := strings.Fields(line)
for i, part := range parts {
if part == "packet" && i+2 < len(parts) {
if loss, err := strconv.ParseFloat(strings.Trim(parts[i+1], "%"), 64); err == nil {
result["packets_losrat"] = loss
}
}
}
}
if strings.Contains(line, "min/avg/max") {
// 解析延迟统计
parts := strings.Fields(line)
for _, part := range parts {
if strings.Contains(part, "/") {
times := strings.Split(part, "/")
if len(times) >= 3 {
if min, err := strconv.ParseFloat(times[0], 64); err == nil {
result["time_min"] = min
}
if avg, err := strconv.ParseFloat(times[1], 64); err == nil {
result["time_avg"] = avg
}
if max, err := strconv.ParseFloat(times[2], 64); err == nil {
result["time_max"] = max
}
}
}
}
}
}
return result
}

View File

@@ -0,0 +1,59 @@
package handler
import (
"net"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func handleSocket(c *gin.Context, url string, params map[string]interface{}) {
// 解析host:port格式
parts := strings.Split(url, ":")
if len(parts) != 2 {
c.JSON(200, gin.H{
"type": "ceSocket",
"url": url,
"error": "格式错误,需要 host:port",
})
return
}
host := parts[0]
portStr := parts[1]
port, err := strconv.Atoi(portStr)
if err != nil {
c.JSON(200, gin.H{
"type": "ceSocket",
"url": url,
"error": "端口格式错误",
})
return
}
// 执行TCP连接测试
conn, err := net.DialTimeout("tcp", net.JoinHostPort(host, portStr), 5*time.Second)
if err != nil {
c.JSON(200, gin.H{
"type": "ceSocket",
"url": url,
"host": host,
"port": port,
"result": "false",
"error": err.Error(),
})
return
}
defer conn.Close()
c.JSON(200, gin.H{
"type": "ceSocket",
"url": url,
"host": host,
"port": port,
"result": "true",
})
}

View File

@@ -0,0 +1,63 @@
package handler
import (
"net"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func handleTCPing(c *gin.Context, url string, params map[string]interface{}) {
// 解析host:port格式
parts := strings.Split(url, ":")
if len(parts) != 2 {
c.JSON(200, gin.H{
"type": "ceTCPing",
"url": url,
"error": "格式错误,需要 host:port",
})
return
}
host := parts[0]
portStr := parts[1]
port, err := strconv.Atoi(portStr)
if err != nil {
c.JSON(200, gin.H{
"type": "ceTCPing",
"url": url,
"error": "端口格式错误",
})
return
}
// 执行TCP连接测试
start := time.Now()
conn, err := net.DialTimeout("tcp", net.JoinHostPort(host, portStr), 5*time.Second)
latency := time.Since(start).Milliseconds()
if err != nil {
c.JSON(200, gin.H{
"type": "ceTCPing",
"url": url,
"host": host,
"port": port,
"latency": -1,
"error": err.Error(),
})
return
}
defer conn.Close()
c.JSON(200, gin.H{
"type": "ceTCPing",
"url": url,
"host": host,
"port": port,
"latency": latency,
"success": true,
})
}

49
internal/handler/test.go Normal file
View File

@@ -0,0 +1,49 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
)
// HandleTest 统一测试接口
func HandleTest(c *gin.Context) {
var req struct {
Type string `json:"type" binding:"required"`
URL string `json:"url" binding:"required"`
Params map[string]interface{} `json:"params"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 根据类型分发到不同的处理器
switch req.Type {
case "ceGet":
handleGet(c, req.URL, req.Params)
case "cePost":
handlePost(c, req.URL, req.Params)
case "cePing":
handlePing(c, req.URL, req.Params)
case "ceDns":
handleDns(c, req.URL, req.Params)
case "ceTrace":
handleTrace(c, req.URL, req.Params)
case "ceSocket":
handleSocket(c, req.URL, req.Params)
case "ceTCPing":
handleTCPing(c, req.URL, req.Params)
case "ceFindPing":
handleFindPing(c, req.URL, req.Params)
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的测试类型"})
}
}
// HandleHealth 健康检查
func HandleHealth(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
}

39
internal/handler/trace.go Normal file
View File

@@ -0,0 +1,39 @@
package handler
import (
"os/exec"
"strings"
"github.com/gin-gonic/gin"
)
func handleTrace(c *gin.Context, url string, params map[string]interface{}) {
// 执行traceroute命令
cmd := exec.Command("traceroute", url)
output, err := cmd.CombinedOutput()
if err != nil {
c.JSON(200, gin.H{
"type": "ceTrace",
"url": url,
"error": err.Error(),
})
return
}
// 解析输出
lines := strings.Split(string(output), "\n")
traceResult := make([]string, 0)
for _, line := range lines {
line = strings.TrimSpace(line)
if line != "" {
traceResult = append(traceResult, line)
}
}
c.JSON(200, gin.H{
"type": "ceTrace",
"url": url,
"trace_result": traceResult,
})
}