Commit 2cd99719 authored by qiuqunfeng's avatar qiuqunfeng
Browse files

Enhance WAF API by adding CountAttackLogs endpoint and updating WafController...

Enhance WAF API by adding CountAttackLogs endpoint and updating WafController to support region-specific log counting. Refactor service methods to accommodate region code in attack log counting logic.
parent aab2db68
......@@ -9,9 +9,10 @@ import (
)
func SetWafApiServerRouter(e *gin.Engine, clusterClientManager *utils.ClusterClientManager, db *gorm.DB, gatewayUrl string, elasticClient *elastic.Client) {
wafController := controller.NewWafController(clusterClientManager, db, gatewayUrl, elasticClient)
wafController := controller.NewWafController(clusterClientManager, db, gatewayUrl, elasticClient, nil)
v2 := e.Group("api/v2/waf")
v2.GET("attack/log/list", wafController.ListAttackLogs)
v2.GET("attack/log/details", wafController.GetAttackLogDetails)
v2.GET("attack/log/rspPkg", wafController.GetAttackLogRsp)
v2.GET("attack/log/count", wafController.CountAttackLogs)
}
......@@ -11,7 +11,7 @@ import (
func SetWafRouter(e *gin.Engine, clusterClientManager *utils.ClusterClientManager, db *gorm.DB, gatewayUrl string, elasticClient *elastic.Client, regionUrlMap map[string]string) {
v1 := e.Group("v1/api/waf")
wafController := controller.NewWafController(clusterClientManager, db, gatewayUrl, elasticClient)
wafController := controller.NewWafController(clusterClientManager, db, gatewayUrl, elasticClient, regionUrlMap)
v1.GET("", wafController.Waf)
v1.GET("list", wafController.ListWafs)
v1.POST("/", wafController.CreateWaf)
......@@ -46,11 +46,12 @@ func SetWafRouter(e *gin.Engine, clusterClientManager *utils.ClusterClientManage
}
func SetApiWafRouter(e *gin.Engine, clusterClientManager *utils.ClusterClientManager, db *gorm.DB, gatewayUrl string, elasticClient *elastic.Client) {
v1 := e.Group("api/v2/waf")
// func SetApiWafRouter(e *gin.Engine, clusterClientManager *utils.ClusterClientManager, db *gorm.DB, gatewayUrl string, elasticClient *elastic.Client) {
// v1 := e.Group("api/v2/waf")
wafController := controller.NewWafController(clusterClientManager, db, gatewayUrl, elasticClient)
v1.GET("attack/log/list", wafController.ListAttackLogs)
v1.GET("attack/log/details", wafController.GetAttackLogDetails)
v1.GET("attack/log/rspPkg", wafController.GetAttackLogRsp)
}
// wafController := controller.NewWafController(clusterClientManager, db, gatewayUrl, elasticClient, nil)
// v1.GET("attack/log/list", wafController.ListAttackLogs)
// v1.GET("attack/log/details", wafController.GetAttackLogDetails)
// v1.GET("attack/log/rspPkg", wafController.GetAttackLogRsp)
// v1.GET("attack/log/count", wafController.CountAttackLogs)
// }
......@@ -18,11 +18,13 @@ import (
type WafController struct {
service service.Service
simpleProxy *service.SimpleProxy
}
func NewWafController(clusterClientManager *utils.ClusterClientManager, db *gorm.DB, gatewayUrl string, elasticClient *elastic.Client) *WafController {
func NewWafController(clusterClientManager *utils.ClusterClientManager, db *gorm.DB, gatewayUrl string, elasticClient *elastic.Client, regionUrlMap map[string]string) *WafController {
return &WafController{
service: service.NewWafService(clusterClientManager, db, gatewayUrl, elasticClient),
simpleProxy: service.NewSimpleProxy(regionUrlMap),
}
}
......@@ -40,6 +42,12 @@ func (c *WafController) Waf(ctx *gin.Context) {
utils.AssembleResponse(ctx, nil, err)
return
}
count, err := c.simpleProxy.CountAttackLogs(ctx1, regionCode, waf.ID)
if err != nil {
utils.AssembleResponse(ctx, nil, err)
return
}
waf.AttackNum = int(count)
resp := &utils.SingleRespData{
Item: waf,
}
......@@ -372,7 +380,23 @@ func (c *WafController) ListAttackLogs(ctx *gin.Context) {
}
func (c *WafController) CountAttackLogs(ctx *gin.Context) {
ctx1, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
regionCode := ctx.Query("region_code")
serviceID := ctx.Query("service_id")
serviceIDUint, err := strconv.ParseUint(serviceID, 10, 32)
if err != nil {
utils.AssembleResponse(ctx, nil, err)
return
}
count, err := c.service.CountAttackLogs(ctx1, regionCode, uint32(serviceIDUint))
if err != nil {
utils.AssembleResponse(ctx, nil, err)
return
}
utils.AssembleResponse(ctx, count, nil)
}
// getLimitAndOffset extracts pagination parameters from the context
......
......@@ -30,5 +30,5 @@ type Service interface {
GetBlackWhiteLists(ctx context.Context, query *MatchExprQueryOption, limit int, offset int) ([]MatcherExpr, int, error)
ListListenerHistory(ctx context.Context, query *WafListenerHistoryOption, limit, offset int) ([]model.WafListenerHistory, int, error)
ListAttackClasses(ctx context.Context, lang string) []AttackClasses
CountAttackLogs(ctx context.Context, serviceID uint32) int64
CountAttackLogs(ctx context.Context, regionCode string, serviceID uint32) (int64, error)
}
package service
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"gitlab.com/tensorsecurity-rd/waf-console/internal/utils"
)
type SimpleProxy struct {
regionUrlMap map[string]string
}
func NewSimpleProxy(regionUrlMap map[string]string) *SimpleProxy {
return &SimpleProxy{
regionUrlMap: regionUrlMap,
}
}
func (s *SimpleProxy) CountAttackLogs(ctx context.Context, region_code string, serviceID uint32) (int64, error) {
remoteUrl, err := url.Parse(s.regionUrlMap[region_code])
if err != nil {
return 0, err
}
remoteUrl = remoteUrl.JoinPath("/api/v2/waf/attack/log/count")
remoteUrl.RawQuery = fmt.Sprintf("service_id=%d&region_code=%s", serviceID, region_code)
resp, err := http.Get(remoteUrl.String())
if err != nil {
return 0, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return 0, fmt.Errorf("failed to count attack logs: %s", resp.Status)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return 0, err
}
var response utils.SuccessResponse
if err := json.Unmarshal(body, &response); err != nil {
return 0, err
}
if response.Code != "OK" {
return 0, fmt.Errorf("failed to count attack logs: %s", response.Message)
}
return response.Data.(int64), nil
}
......@@ -123,25 +123,25 @@ func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayN
}
// Count attack logs for current day
now := time.Now()
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
endOfDay := startOfDay.Add(24 * time.Hour)
boolQuery := elastic.NewBoolQuery()
boolQuery.Must(elastic.NewTermQuery("service_id", wafService.ID))
boolQuery.Filter(elastic.NewRangeQuery("attack_time").
Gte(startOfDay.UnixMilli()).
Lt(endOfDay.UnixMilli()))
boolQuery.Filter(elastic.NewBoolQuery().MustNot(elastic.NewTermQuery("action", "pass")))
result, err := s.elasticClient.Count("waf-detections*").
Query(boolQuery).
Do(ctx)
if err != nil {
return nil, fmt.Errorf("failed to count attack logs: %v", err)
}
// now := time.Now()
// startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
// endOfDay := startOfDay.Add(24 * time.Hour)
// boolQuery := elastic.NewBoolQuery()
// boolQuery.Must(elastic.NewTermQuery("service_id", wafService.ID))
// boolQuery.Filter(elastic.NewRangeQuery("attack_time").
// Gte(startOfDay.UnixMilli()).
// Lt(endOfDay.UnixMilli()))
// boolQuery.Filter(elastic.NewBoolQuery().MustNot(elastic.NewTermQuery("action", "pass")))
// result, err := s.elasticClient.Count("waf-detections*").
// Query(boolQuery).
// Do(ctx)
// if err != nil {
// return nil, fmt.Errorf("failed to count attack logs: %v", err)
// }
wafService.AttackNum = int(result)
// wafService.AttackNum = int(result)
return &WafService{
GatewayName: wafService.GatewayName,
Mode: wafService.Mode,
......@@ -151,13 +151,14 @@ func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayN
}, nil
}
func (s *wafService) CountAttackLogs(ctx context.Context, serviceID uint32) int64 {
func (s *wafService) CountAttackLogs(ctx context.Context, regionCode string, serviceID uint32) (int64, error) {
now := time.Now()
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
endOfDay := startOfDay.Add(24 * time.Hour)
boolQuery := elastic.NewBoolQuery()
boolQuery.Must(elastic.NewTermQuery("service_id", serviceID))
boolQuery.Filter(elastic.NewTermQuery("cluster_key", regionCode))
boolQuery.Filter(elastic.NewRangeQuery("attack_time").
Gte(startOfDay.UnixMilli()).
Lt(endOfDay.UnixMilli()))
......@@ -168,10 +169,10 @@ func (s *wafService) CountAttackLogs(ctx context.Context, serviceID uint32) int6
Do(ctx)
if err != nil {
log.Err(fmt.Errorf("failed to count attack logs: %v", err))
return 0
return 0, err
}
return result
return result, nil
}
func (s *wafService) ListWafs(ctx context.Context) ([]WafService, error) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment