server.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. // Copyright 2023 The frp Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package ssh
  15. import (
  16. "context"
  17. "encoding/binary"
  18. "errors"
  19. "fmt"
  20. "net"
  21. "slices"
  22. "strings"
  23. "sync"
  24. "time"
  25. libio "github.com/fatedier/golib/io"
  26. "github.com/spf13/cobra"
  27. flag "github.com/spf13/pflag"
  28. "golang.org/x/crypto/ssh"
  29. "github.com/fatedier/frp/client/proxy"
  30. "github.com/fatedier/frp/pkg/config"
  31. v1 "github.com/fatedier/frp/pkg/config/v1"
  32. "github.com/fatedier/frp/pkg/msg"
  33. "github.com/fatedier/frp/pkg/util/log"
  34. netpkg "github.com/fatedier/frp/pkg/util/net"
  35. "github.com/fatedier/frp/pkg/util/util"
  36. "github.com/fatedier/frp/pkg/util/xlog"
  37. "github.com/fatedier/frp/pkg/virtual"
  38. )
const (
	// ChannelTypeServerOpenChannel is the channel type the server opens back to
	// the SSH client for every connection that should be forwarded to it.
	// https://datatracker.ietf.org/doc/html/rfc4254#page-16
	ChannelTypeServerOpenChannel = "forwarded-tcpip"
	// RequestTypeForward is the global request type a client sends to ask the
	// server for remote (server-side) port forwarding.
	RequestTypeForward = "tcpip-forward"
)
// tcpipForward is the payload of a "tcpip-forward" global request: the address
// and port the client asks the server to bind. Field order is significant
// because the payload is decoded with ssh.Unmarshal, which reads fields in
// declaration order.
type tcpipForward struct {
	Host string
	Port uint32
}
// forwardedTCPPayload is the channel-open payload of a "forwarded-tcpip"
// channel: the forwarded address/port followed by the originator address/port.
// Field order is significant because it is encoded with ssh.Marshal, which
// writes fields in declaration order.
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
type forwardedTCPPayload struct {
	Addr       string
	Port       uint32
	OriginAddr string
	OriginPort uint32
}
// TunnelServer serves a single SSH connection whose client requests remote port
// forwarding through frp. For that session it runs an in-process virtual frp
// client configured from the client's exec payload.
type TunnelServer struct {
	// underlyingConn is the raw connection the SSH handshake runs over.
	underlyingConn net.Conn
	// sshConn is set by Run after the SSH handshake succeeds.
	sshConn *ssh.ServerConn
	// sc is the SSH server configuration used for the handshake.
	sc *ssh.ServerConfig
	// firstChannel is the first channel accepted from the client; writeToClient
	// sends user-facing messages over it.
	firstChannel ssh.Channel
	// vc is the virtual frp client created by Run.
	vc *virtual.Client
	// peerServerListener receives every connection accepted on the virtual
	// client's peer listener.
	peerServerListener *netpkg.InternalListener
	// doneCh is closed exactly once (guarded by closeDoneChOnce) when the session
	// ends; keepAlive and waitProxyStatusReady stop when it is closed.
	doneCh          chan struct{}
	closeDoneChOnce sync.Once
}
  65. func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, peerServerListener *netpkg.InternalListener) (*TunnelServer, error) {
  66. s := &TunnelServer{
  67. underlyingConn: conn,
  68. sc: sc,
  69. peerServerListener: peerServerListener,
  70. doneCh: make(chan struct{}),
  71. }
  72. return s, nil
  73. }
// Run performs the SSH server handshake on the underlying connection and
// serves one tunnel session over it. In order it:
//   - waits up to 3 seconds for the client's tcpip-forward address and the
//     exec payload carrying frp flags;
//   - parses that payload into a client common config and a proxy config;
//   - starts a virtual frp client for the proxy and moves every connection
//     accepted on its peer listener into peerServerListener;
//   - waits up to 1 second for the proxy to become ready and reports the
//     outcome to the SSH client; on success it then blocks until the SSH
//     connection ends.
//
// A help request (flag.ErrHelp) writes the usage text to the client and
// returns nil. Other parse/config errors are written to the client and
// returned. Once virtual.NewClient succeeds, Run always returns nil.
//
// NOTE(review): the early error returns leave sshConn open and doneCh
// unclosed; presumably the caller closes the underlying conn — confirm.
func (s *TunnelServer) Run() error {
	sshConn, channels, requests, err := ssh.NewServerConn(s.underlyingConn, s.sc)
	if err != nil {
		return err
	}
	s.sshConn = sshConn

	addr, extraPayload, err := s.waitForwardAddrAndExtraPayload(channels, requests, 3*time.Second)
	if err != nil {
		return err
	}

	clientCfg, pc, helpMessage, err := s.parseClientAndProxyConfigurer(addr, extraPayload)
	if err != nil {
		// A help request is not a failure: show usage and end quietly.
		if errors.Is(err, flag.ErrHelp) {
			s.writeToClient(helpMessage)
			return nil
		}
		s.writeToClient(err.Error())
		return fmt.Errorf("parse flags from ssh client error: %v", err)
	}
	if err := clientCfg.Complete(); err != nil {
		s.writeToClient(fmt.Sprintf("failed to complete client config: %v", err))
		return fmt.Errorf("complete client config error: %v", err)
	}
	// A "user" extension recorded during SSH authentication takes precedence
	// over the user from flags (util.EmptyOr presumably falls back to
	// clientCfg.User when the extension is empty — confirm).
	if sshConn.Permissions != nil {
		clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User)
	}
	pc.Complete(clientCfg.User)

	vc, err := virtual.NewClient(virtual.ClientOptions{
		Common: clientCfg,
		Spec: &msg.ClientSpec{
			Type: "ssh-tunnel",
			// If ssh does not require authentication, then the virtual client needs to authenticate through a token.
			// Otherwise, once ssh authentication is passed, the virtual client does not need to authenticate again.
			AlwaysAuthPass: !s.sc.NoClientAuth,
		},
		// Each work connection is bridged to a freshly opened forwarded-tcpip
		// channel back to the SSH client.
		// NOTE(review): returns false on both paths; the meaning of the return
		// value is defined by virtual.ClientOptions — confirm.
		HandleWorkConnCb: func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool {
			// join workConn and ssh channel
			c, err := s.openConn(addr)
			if err != nil {
				log.Tracef("open conn error: %v", err)
				workConn.Close()
				return false
			}
			libio.Join(c, workConn)
			return false
		},
	})
	if err != nil {
		return err
	}
	s.vc = vc

	// transfer connection from virtual client to server peer listener
	// The goroutine exits when Accept returns an error; PutConn errors are
	// deliberately ignored.
	go func() {
		l := s.vc.PeerListener()
		for {
			conn, err := l.Accept()
			if err != nil {
				return
			}
			_ = s.peerServerListener.PutConn(conn)
		}
	}()

	xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100})
	ctx := xlog.NewContext(context.Background(), xl)
	go func() {
		vcErr := s.vc.Run(ctx)
		if vcErr != nil {
			s.writeToClient(vcErr.Error())
		}

		// If vc.Run returns, it means that the virtual client has been closed, and the ssh tunnel connection should be closed.
		// One scenario is that the virtual client exits due to login failure.
		s.closeDoneChOnce.Do(func() {
			_ = sshConn.Close()
			close(s.doneCh)
		})
	}()

	s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc})

	if ps, err := s.waitProxyStatusReady(pc.GetBaseConfig().Name, time.Second); err != nil {
		s.writeToClient(err.Error())
		log.Warnf("wait proxy status ready error: %v", err)
	} else {
		// success
		s.writeToClient(createSuccessInfo(clientCfg.User, pc, ps))
		// Block for the lifetime of the SSH connection.
		_ = sshConn.Wait()
	}

	s.vc.Close()
	log.Tracef("ssh tunnel connection from %v closed", sshConn.RemoteAddr())
	// Same once-guard as the vc.Run goroutine, so whichever side finishes
	// first closes sshConn and doneCh, and the other is a no-op.
	s.closeDoneChOnce.Do(func() {
		_ = sshConn.Close()
		close(s.doneCh)
	})
	return nil
}
  167. func (s *TunnelServer) writeToClient(data string) {
  168. if s.firstChannel == nil {
  169. return
  170. }
  171. _, _ = s.firstChannel.Write([]byte(data + "\n"))
  172. }
  173. func (s *TunnelServer) waitForwardAddrAndExtraPayload(
  174. channels <-chan ssh.NewChannel,
  175. requests <-chan *ssh.Request,
  176. timeout time.Duration,
  177. ) (*tcpipForward, string, error) {
  178. addrCh := make(chan *tcpipForward, 1)
  179. extraPayloadCh := make(chan string, 1)
  180. // get forward address
  181. go func() {
  182. addrGot := false
  183. for req := range requests {
  184. if req.Type == RequestTypeForward && !addrGot {
  185. payload := tcpipForward{}
  186. if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
  187. return
  188. }
  189. addrGot = true
  190. addrCh <- &payload
  191. }
  192. if req.WantReply {
  193. _ = req.Reply(true, nil)
  194. }
  195. }
  196. }()
  197. // get extra payload
  198. go func() {
  199. for newChannel := range channels {
  200. // extraPayload will send to extraPayloadCh
  201. go s.handleNewChannel(newChannel, extraPayloadCh)
  202. }
  203. }()
  204. var (
  205. addr *tcpipForward
  206. extraPayload string
  207. )
  208. timer := time.NewTimer(timeout)
  209. defer timer.Stop()
  210. for {
  211. select {
  212. case v := <-addrCh:
  213. addr = v
  214. case extra := <-extraPayloadCh:
  215. extraPayload = extra
  216. case <-timer.C:
  217. return nil, "", fmt.Errorf("get addr and extra payload timeout")
  218. }
  219. if addr != nil && extraPayload != "" {
  220. break
  221. }
  222. }
  223. return addr, extraPayload, nil
  224. }
  225. func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, string, error) {
  226. helpMessage := ""
  227. cmd := &cobra.Command{
  228. Use: "ssh v0@{address} [command]",
  229. Short: "ssh v0@{address} [command]",
  230. Run: func(*cobra.Command, []string) {},
  231. }
  232. cmd.SetGlobalNormalizationFunc(config.WordSepNormalizeFunc)
  233. args := strings.Split(extraPayload, " ")
  234. if len(args) < 1 {
  235. return nil, nil, helpMessage, fmt.Errorf("invalid extra payload")
  236. }
  237. proxyType := strings.TrimSpace(args[0])
  238. supportTypes := []string{"tcp", "http", "https", "tcpmux", "stcp"}
  239. if !slices.Contains(supportTypes, proxyType) {
  240. return nil, nil, helpMessage, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes)
  241. }
  242. pc := v1.NewProxyConfigurerByType(v1.ProxyType(proxyType))
  243. if pc == nil {
  244. return nil, nil, helpMessage, fmt.Errorf("new proxy configurer error")
  245. }
  246. config.RegisterProxyFlags(cmd, pc, config.WithSSHMode())
  247. clientCfg := v1.ClientCommonConfig{}
  248. config.RegisterClientCommonConfigFlags(cmd, &clientCfg, config.WithSSHMode())
  249. cmd.InitDefaultHelpCmd()
  250. if err := cmd.ParseFlags(args); err != nil {
  251. if errors.Is(err, flag.ErrHelp) {
  252. helpMessage = cmd.UsageString()
  253. }
  254. return nil, nil, helpMessage, err
  255. }
  256. // if name is not set, generate a random one
  257. if pc.GetBaseConfig().Name == "" {
  258. id, err := util.RandIDWithLen(8)
  259. if err != nil {
  260. return nil, nil, helpMessage, fmt.Errorf("generate random id error: %v", err)
  261. }
  262. pc.GetBaseConfig().Name = fmt.Sprintf("sshtunnel-%s-%s", proxyType, id)
  263. }
  264. return &clientCfg, pc, helpMessage, nil
  265. }
  266. func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh chan string) {
  267. ch, reqs, err := channel.Accept()
  268. if err != nil {
  269. return
  270. }
  271. if s.firstChannel == nil {
  272. s.firstChannel = ch
  273. }
  274. go s.keepAlive(ch)
  275. for req := range reqs {
  276. if req.WantReply {
  277. _ = req.Reply(true, nil)
  278. }
  279. if req.Type != "exec" || len(req.Payload) <= 4 {
  280. continue
  281. }
  282. end := 4 + binary.BigEndian.Uint32(req.Payload[:4])
  283. if len(req.Payload) < int(end) {
  284. continue
  285. }
  286. extraPayload := string(req.Payload[4:end])
  287. select {
  288. case extraPayloadCh <- extraPayload:
  289. default:
  290. }
  291. }
  292. }
  293. func (s *TunnelServer) keepAlive(ch ssh.Channel) {
  294. tk := time.NewTicker(time.Second * 30)
  295. defer tk.Stop()
  296. for {
  297. select {
  298. case <-tk.C:
  299. _, err := ch.SendRequest("heartbeat", false, nil)
  300. if err != nil {
  301. return
  302. }
  303. case <-s.doneCh:
  304. return
  305. }
  306. }
  307. }
  308. func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) {
  309. payload := forwardedTCPPayload{
  310. Addr: addr.Host,
  311. Port: addr.Port,
  312. // Note: Here is just for compatibility, not the real source address.
  313. OriginAddr: addr.Host,
  314. OriginPort: addr.Port,
  315. }
  316. channel, reqs, err := s.sshConn.OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(&payload))
  317. if err != nil {
  318. return nil, fmt.Errorf("open ssh channel error: %v", err)
  319. }
  320. go ssh.DiscardRequests(reqs)
  321. conn := netpkg.WrapReadWriteCloserToConn(channel, s.underlyingConn)
  322. return conn, nil
  323. }
  324. func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) (*proxy.WorkingStatus, error) {
  325. ticker := time.NewTicker(100 * time.Millisecond)
  326. defer ticker.Stop()
  327. timer := time.NewTimer(timeout)
  328. defer timer.Stop()
  329. statusExporter := s.vc.Service().StatusExporter()
  330. for {
  331. select {
  332. case <-ticker.C:
  333. ps, ok := statusExporter.GetProxyStatus(name)
  334. if !ok {
  335. continue
  336. }
  337. switch ps.Phase {
  338. case proxy.ProxyPhaseRunning:
  339. return ps, nil
  340. case proxy.ProxyPhaseStartErr, proxy.ProxyPhaseClosed:
  341. return ps, errors.New(ps.Err)
  342. }
  343. case <-timer.C:
  344. return nil, fmt.Errorf("wait proxy status ready timeout")
  345. case <-s.doneCh:
  346. return nil, fmt.Errorf("ssh tunnel server closed")
  347. }
  348. }
  349. }