// pubsub.go
package redis

import (
	"sync"

	"github.com/garyburd/redigo/redis"
	//"math/rand"
	"errors"
	//"sync"
	"fmt"
	"time"
)

// subCallBack receives the events read from the pub/sub connection by the
// receive loop started in StartSubscribe.
type subCallBack interface {
	// OnMessage is called for a message delivered on a subscribed channel.
	OnMessage(channel string, msg []byte)
	// OnPMessage is called for a message delivered through a pattern
	// subscription; pattern is the matching pattern and channel is the
	// channel the message was published on.
	OnPMessage(pattern string, channel string, msg []byte)
	// OnSubscription is called for subscription notifications, passing the
	// kind, channel and count reported by the connection.
	OnSubscription(kind string, channel string, count int)
	// OnError is called when receiving from the connection returns an error,
	// before the client reconnects.
	OnError(err error)
}

// PubSubClient publishes through a connection pool and receives through a
// single dedicated pub/sub connection. It remembers which channels and
// patterns were subscribed so they can be replayed after a reconnect.
type PubSubClient struct {
	// m_host is the address passed to redis.Dial over "tcp".
	m_host        string
	// db_index is the database index handed to redis.DialDatabase.
	db_index      uint32
	// password is handed to redis.DialPassword.
	password      string
	// m_pool supplies the connections used by Publish.
	m_pool        *redis.Pool
	// m_conn is the dedicated connection used for (un)subscribe commands and
	// for Receive in the loop started by StartSubscribe.
	m_conn        redis.PubSubConn
	// m_lock is held by Subscribe, Unsubscribe, PSubscribe and PUnsubscribe
	// while they send on m_conn and update the maps below.
	m_lock        sync.Mutex
	// m_subCallback is the handler invoked by the receive loop.
	m_subCallback subCallBack
	// m_subList holds the channel names subscribed so far (used as a set).
	m_subList     map[string]bool
	// m_psubList holds the patterns subscribed so far (used as a set).
	m_psubList    map[string]bool
}

// NewPubSubClient dials the dedicated pub/sub connection to host (selecting
// database db and authenticating with password) and returns a client whose
// pool dials with the same settings for Publish.
//
// If the initial dial fails, the error is printed to stdout and nil is
// returned.
func NewPubSubClient(host string, password string, db uint32) *PubSubClient {
	// Open the subscription connection first; without it there is no client.
	subConn, dialErr := redis.Dial("tcp", host, redis.DialDatabase(int(db)), redis.DialPassword(password))
	if dialErr != nil {
		fmt.Printf("error dial:%v", dialErr)
		return nil
	}

	pool := &redis.Pool{
		MaxIdle:     64,
		IdleTimeout: 60 * time.Second,
		// Connections handed out by the pool are created lazily here.
		Dial: func() (redis.Conn, error) {
			pooled, err := redis.Dial("tcp", host, redis.DialDatabase(int(db)), redis.DialPassword(password))
			if err != nil {
				return nil, err
			}
			return pooled, nil
		},
		// A connection is checked with PING before it is reused.
		TestOnBorrow: func(c redis.Conn, t time.Time) error {
			_, pingErr := c.Do("PING")
			return pingErr
		},
	}

	return &PubSubClient{
		m_host:     host,
		db_index:   db,
		password:   password,
		m_pool:     pool,
		m_conn:     redis.PubSubConn{Conn: subConn},
		m_subList:  make(map[string]bool),
		m_psubList: make(map[string]bool),
	}
}

// Publish sends value to channel with the PUBLISH command, using a
// connection borrowed from the pool and released when the call returns.
// It returns the error from the command, if any.
func (client *PubSubClient) Publish(channel, value interface{}) error {
	pooled := client.m_pool.Get()
	defer pooled.Close()

	_, err := pooled.Do("PUBLISH", channel, value)
	return err
}

// Subscribe subscribes the pub/sub connection to channel and, on success,
// records the channel so it can be replayed after a reconnect.
// The connection error is returned unchanged on failure.
func (client *PubSubClient) Subscribe(channel string) error {
	client.m_lock.Lock()
	defer client.m_lock.Unlock()

	err := client.m_conn.Subscribe(channel)
	if err == nil {
		// Only remember channels the connection accepted.
		client.m_subList[channel] = true
	}
	return err
}

// Unsubscribe removes the subscription to channel. An empty channel sends
// an argument-less unsubscribe and forgets every recorded channel.
// On a connection error the recorded channels are left untouched.
func (client *PubSubClient) Unsubscribe(channel string) error {
	client.m_lock.Lock()
	defer client.m_lock.Unlock()

	// Single-channel case.
	if channel != "" {
		if err := client.m_conn.Unsubscribe(channel); err != nil {
			return err
		}
		delete(client.m_subList, channel)
		return nil
	}

	// Empty name: unsubscribe without arguments, then drop all bookkeeping.
	err := client.m_conn.Unsubscribe()
	if err != nil {
		return err
	}
	for name := range client.m_subList {
		delete(client.m_subList, name)
	}
	return nil
}

// PSubscribe subscribes the pub/sub connection to the pattern given in
// channel and, on success, records the pattern so it can be replayed after
// a reconnect. The connection error is returned unchanged on failure.
func (client *PubSubClient) PSubscribe(channel string) error {
	client.m_lock.Lock()
	defer client.m_lock.Unlock()

	err := client.m_conn.PSubscribe(channel)
	if err == nil {
		// Only remember patterns the connection accepted.
		client.m_psubList[channel] = true
	}
	return err
}

// PUnsubscribe removes the pattern subscription given in channel. An empty
// channel sends an argument-less punsubscribe and forgets every recorded
// pattern. On a connection error the recorded patterns are left untouched.
//
// Bookkeeping is done on m_psubList (the pattern set), not m_subList, so a
// removed pattern is no longer replayed by reSubscribe and plain channel
// subscriptions are not disturbed.
func (client *PubSubClient) PUnsubscribe(channel string) error {
	client.m_lock.Lock()
	defer client.m_lock.Unlock()

	if channel == "" {
		if err := client.m_conn.PUnsubscribe(); err != nil {
			return err
		}
		// Forget every recorded pattern.
		for k := range client.m_psubList {
			delete(client.m_psubList, k)
		}
		return nil
	}

	if err := client.m_conn.PUnsubscribe(channel); err != nil {
		return err
	}
	delete(client.m_psubList, channel)
	return nil
}

// SetSubCallback replaces the callback stored on the client. The assignment
// is made without taking m_lock.
func (client *PubSubClient) SetSubCallback(cb subCallBack) {
	client.m_subCallback = cb
}

// reConnect closes the current pub/sub connection and blocks until a new one
// has been dialed with the client's stored host, database and password,
// retrying once per second for as long as dialing fails.
//
// The new connection is installed while holding m_lock, because Subscribe,
// Unsubscribe, PSubscribe and PUnsubscribe read m_conn under that lock from
// other goroutines.
func (client *PubSubClient) reConnect() {
	// 1. Close the old connection (best effort; its error is not useful here).
	// The lock is not taken for this step so the close is never delayed by a
	// caller that holds m_lock while sending on the old connection.
	client.m_conn.Close()

	// 2. Dial until a new connection is established.
	for {
		conn, err := redis.Dial("tcp", client.m_host, redis.DialDatabase(int(client.db_index)), redis.DialPassword(client.password))
		if err == nil {
			client.m_lock.Lock()
			client.m_conn = redis.PubSubConn{Conn: conn}
			client.m_lock.Unlock()
			return
		}
		time.Sleep(time.Second)
	}
}

// reSubscribe replays every recorded channel and pattern subscription on the
// current pub/sub connection.
//
// m_lock is held for the whole replay so the maps cannot be modified by a
// concurrent Subscribe/Unsubscribe/PSubscribe/PUnsubscribe while they are
// being iterated. The commands are sent on m_conn directly: going through
// client.Subscribe or client.PSubscribe would try to take m_lock again and
// deadlock.
func (client *PubSubClient) reSubscribe() {
	client.m_lock.Lock()
	defer client.m_lock.Unlock()

	// Replay is best effort; a failed command does not stop the remaining
	// subscriptions from being attempted and its error is discarded.
	for channel := range client.m_subList {
		_ = client.m_conn.Subscribe(channel)
	}
	for pattern := range client.m_psubList {
		_ = client.m_conn.PSubscribe(pattern)
	}
}

// StartSubscribe stores cb as the client's callback and starts a goroutine
// that receives from the pub/sub connection and dispatches each event to the
// callback stored on the client at that moment:
//
//   - redis.Message      -> OnMessage
//   - redis.PMessage     -> OnPMessage
//   - redis.Subscription -> OnSubscription
//   - error              -> OnError, then reconnect and replay subscriptions
//
// It returns an error if cb is nil and nil otherwise. The goroutine loops
// forever; this function provides no way to stop it.
func (client *PubSubClient) StartSubscribe(cb subCallBack) error {
	if cb == nil {
		return errors.New("subCallback is not set!")
	}
	client.m_subCallback = cb

	go func() {
		for {
			// A method value taken from a non-nil interface is never nil, so
			// the handlers are called directly without nil comparisons.
			switch n := client.m_conn.Receive().(type) {
			case redis.Message:
				client.m_subCallback.OnMessage(n.Channel, n.Data)
			case redis.PMessage:
				client.m_subCallback.OnPMessage(n.Pattern, n.Channel, n.Data)
			case redis.Subscription:
				client.m_subCallback.OnSubscription(n.Kind, n.Channel, n.Count)
			case error:
				client.m_subCallback.OnError(n)
				// Report first, then replace the connection and restore the
				// recorded subscriptions before receiving again.
				client.reConnect()
				client.reSubscribe()
			}
		}
	}()

	return nil
}