diff --git a/cmd/stock/main.go b/cmd/stock/main.go index 19654b0..b72fc53 100644 --- a/cmd/stock/main.go +++ b/cmd/stock/main.go @@ -11,7 +11,6 @@ import ( "net/http" "stock/cfg" "stock/fund" - "stock/module" "stock/msg" "stock/stock" "stock/user" @@ -89,7 +88,10 @@ func main() { } fund.Clear() - user.ForEachUser(func(u module.IUser) bool { + user.ForEachUser(func(u *user.User) bool { + if u.IsStop() { + return true + } codes := u.Codes(true) stk := fund.NewFundArg(codes...) err = wxgzh.Send(u.OpenID(), stk) @@ -153,7 +155,10 @@ func main() { continue } - user.ForEachUser(func(u module.IUser) bool { + user.ForEachUser(func(u *user.User) bool { + if u.IsStop() { + return true + } codes := u.Codes(false) stk, err := stock.GetStocks(codes...) if err != nil { diff --git a/fund/fund.go b/fund/fund.go index cbb6a14..1bcabc3 100644 --- a/fund/fund.go +++ b/fund/fund.go @@ -21,6 +21,7 @@ import ( "regexp" "strconv" "sync" + "time" ) const ( @@ -40,6 +41,7 @@ type fund struct { EstimateVal string `json:"gsz"` RisePer string `json:"gszzl"` UpdateTime string `json:"gztime"` + CheckTime time.Time } func (f *fund) Name() string { @@ -64,11 +66,21 @@ func (f *fund) Update() error { f.EstimateVal = ff.EstimateVal f.RisePer = ff.RisePer f.UpdateTime = ff.UpdateTime + f.CheckTime = time.Now() + } return err } func (f *fund) Msg() string { + now := time.Now() + h := now.Hour() + if h > 9 && h < 15 && now.Sub(f.CheckTime) > time.Minute*5 { + err := f.Update() + if err != nil { + logx.Error(err) + } + } var rise string last, err1 := strconv.ParseFloat(f.UnitVal, 64) cur, err2 := strconv.ParseFloat(f.EstimateVal, 64) diff --git a/module/user.go b/module/user.go deleted file mode 100644 index 667e396..0000000 --- a/module/user.go +++ /dev/null @@ -1,21 +0,0 @@ -/** - * @Author: jager - * @Email: lhj168os@gmail.com - * @File: user - * @Date: 2021/12/21 3:17 下午 - * @package: module - * @Version: v1.0.0 - * - * @Description: - * - */ - -package module - -type IUser interface { - OpenID() string - Codes(isFund bool) []string - HasSubscribed(isFund bool, code string) bool - Subscribe(isFund bool, codes ...string) - UnSubscribe(isFund bool, codes ...string) -} diff --git a/msg/msg_test.go b/msg/msg_test.go index a02d618..585db26 100644 --- a/msg/msg_test.go +++ b/msg/msg_test.go @@ -13,21 +13,8 @@ package msg import ( - xml2 "encoding/xml" - "fmt" "testing" - "time" ) func Test_Post(t *testing.T) { - - wMsg := &xml{ - ToUserName: "rMsg.FromUserName", - FromUserName: "rMsg.ToUserName", - CreateTime: int(time.Now().Unix()), - MsgType: "text", - Content: "收到,谢谢!", - } - bty, _ := xml2.Marshal(wMsg) - fmt.Println(string(bty)) } diff --git a/stock/stock.go b/stock/stock.go index f795962..bea2f00 100644 --- a/stock/stock.go +++ b/stock/stock.go @@ -160,7 +160,6 @@ func (s *stock) Msg() string { // ========================== func Init(codes ...string) error { - fds = &stocks{stkMap: map[string]*stock{}} return fds.AddCodes(codes...) } @@ -172,9 +171,15 @@ type stocks struct { func (sk *stocks) getStocks(codes ...string) *stocks { mx.RLock() defer mx.RUnlock() - var stks = &stocks{} + var stks = &stocks{ + stkMap: map[string]*stock{}, + } for _, code := range codes { - stks.stkMap[code] = sk.stkMap[code] + if kk, ok := sk.stkMap[code]; ok { + stks.stkMap[code] = kk + } else { + logx.Errorf("ErrStockCode=%s", code) + } } return stks } @@ -288,9 +293,6 @@ func (sk *stocks) Arg(openid string) map[string]interface{} { } func GetStock(code string) (*stock, error) { - if fds == nil { - fds = &stocks{stkMap: map[string]*stock{}} - } err := fds.AddCodes(code) if err != nil { return nil, err @@ -307,9 +309,6 @@ func GetStocks(codes ...string) (*stocks, error) { if len(codes) <= 0 { return nil, errcode.New(1, "股票代码为空") } - if fds == nil { - fds = &stocks{stkMap: map[string]*stock{}} - } err := fds.AddCodes(codes...) if err != nil { return nil, err diff --git a/stock/stock_test.go b/stock/stock_test.go index ac2bffa..87534e5 100644 --- a/stock/stock_test.go +++ b/stock/stock_test.go @@ -6,7 +6,7 @@ import ( ) func Test_newStocks(t *testing.T) { - ss, err := NewStocks() + ss, err := GetStocks("600905") if err != nil { fmt.Println(err) return diff --git a/user/cache.go b/user/cache.go index 9b1a82b..7fd24f2 100644 --- a/user/cache.go +++ b/user/cache.go @@ -15,7 +15,6 @@ package user import ( "github.com/jageros/hawox/attribute" "github.com/jageros/hawox/logx" - "stock/module" "sync" ) @@ -66,7 +65,7 @@ func GetUser(openId string) (*User, error) { return us, nil } -func ForEachUser(f func(u module.IUser) bool) { +func ForEachUser(f func(u *User) bool) { mx.Lock() defer mx.Unlock() for _, u := range users { @@ -78,7 +77,10 @@ func ForEachUser(f func(u module.IUser) bool) { func Codes(isFund bool) []string { var codes = map[string]struct{}{} - ForEachUser(func(u module.IUser) bool { + ForEachUser(func(u *User) bool { + if u.IsStop() { + return true + } cds := u.Codes(isFund) for _, cd := range cds { if _, ok := codes[cd]; ok { diff --git a/user/user.go b/user/user.go index cc0fc1f..96a8483 100644 --- a/user/user.go +++ b/user/user.go @@ -39,6 +39,25 @@ func (u *User) OpenID() string { return u.attr.GetAttrID().(string) } +func (u *User) Stop() { + u.mx.Lock() + u.attr.SetBool("stop", true) + u.mx.Unlock() +} + +func (u *User) Start() { + u.mx.Lock() + u.attr.SetBool("stop", false) + u.mx.Unlock() +} + +func (u *User) IsStop() bool { + u.mx.RLock() + stop := u.attr.GetBool("stop") + u.mx.RUnlock() + return stop +} + func (u *User) Codes(isFund bool) []string { u.mx.RLock() defer u.mx.RUnlock() @@ -52,7 +71,10 @@ func (u *User) Codes(isFund bool) []string { } var codes []string attr.ForEachKey(func(key string) bool { - codes = append(codes, key) + sAttr := attr.GetMapAttr(key) + if sAttr.GetBool("notify") { + codes = append(codes, key) + } return true }) return codes diff --git a/wxgzh/wxgzh.go b/wxgzh/wxgzh.go index 29ea459..7e8f5b7 100644 --- a/wxgzh/wxgzh.go +++ b/wxgzh/wxgzh.go @@ -20,7 +20,6 @@ import ( "github.com/jageros/hawox/logx" "net/http" "stock/fund" - "stock/module" "stock/stock" "stock/user" "strings" @@ -116,19 +115,19 @@ func Send(openId string, stk IArg) error { return send(openId, stk, true) } -func SendAll(stk IArg) error { - user.ForEachUser(func(u module.IUser) bool { - if u.HasSubscribed(false, "") { - err := Send(u.OpenID(), stk) - if err != nil { - logx.Error(err) - return false - } - } - return true - }) - return nil -} +//func SendAll(stk IArg) error { +// user.ForEachUser(func(u *user.User) bool { +// if u.HasSubscribed(false, "") { +// err := Send(u.OpenID(), stk) +// if err != nil { +// logx.Error(err) +// return false +// } +// } +// return true +// }) +// return nil +//} type rData struct { ToUserName string `xml:"ToUserName"` @@ -137,6 +136,7 @@ type rData struct { MsgType string `xml:"MsgType"` Content string `xml:"Content"` MsgID int64 `xml:"MsgId"` + Event string `xml:"Event"` } type xml struct { @@ -163,18 +163,52 @@ func Handle(c *gin.Context) { CreateTime: time.Now().Unix(), Content: "查询股票:=st股票代码(例如:=st600905)\n查询基金:=fd基金代码(例如:=fd161725)\n\n" + "订阅股票:+st股票代码(例如:+st600905)\n订阅基金:+fd基金代码(例如:+fd161725)\n\n" + - "取消订阅股票:-st股票代码(例如:-st600905)\n取消订阅基金:-fd基金代码(例如:-fd161725)", + "取消订阅股票:-st股票代码(例如:-st600905)\n取消订阅基金:-fd基金代码(例如:-fd161725)\n\n" + + "查询已订阅股票:lst\n" + + "查询已订阅基金:lfd\n" + + "停止通知:stop\n" + + "启动通知:start\n", + } + + u, err := user.GetUser(rMsg.FromUserName) + if err != nil { + wMsg.Content = err.Error() + rMsg.MsgType = "" } if rMsg.MsgType == "text" { - u, err := user.GetUser(rMsg.FromUserName) - if err != nil { - c.String(http.StatusOK, err.Error()) - c.Abort() - return - } switch { + case strings.ToLower(rMsg.Content) == "stop": + u.Stop() + wMsg.Content = "已停止通知!" + case strings.ToLower(rMsg.Content) == "start": + u.Start() + wMsg.Content = "已启动通知!" + + case strings.ToLower(rMsg.Content) == "lst": + codes := u.Codes(false) + stks, err := stock.GetStocks(codes...) + if err != nil { + wMsg.Content = "查询错误:\n" + err.Error() + } else { + msg := stks.Msg() + if msg == "" { + wMsg.Content = "您未订阅任何股票!" + } else { + wMsg.Content = "所有订阅股票信息:\n" + msg + } + } + + case strings.ToLower(rMsg.Content) == "lfd": + codes := u.Codes(true) + msg := fund.FundsMsg(codes...) + if msg == "" { + wMsg.Content = "您未订阅任何基金!" + } else { + wMsg.Content = "所有订阅基金信息:\n" + msg + } + case len(rMsg.Content) < 9: break @@ -227,5 +261,19 @@ func Handle(c *gin.Context) { } } + if rMsg.MsgType == "event" { + switch rMsg.Event { + case "subscribe": + logx.Infof("user=%s subscribe!", u.OpenID()) + u.Start() + case "unsubscribe": + logx.Infof("user=%s unsubscribe!", u.OpenID()) + u.Stop() + c.String(http.StatusOK, "success") + c.Abort() + return + } + } + c.XML(http.StatusOK, wMsg) }