diff --git a/src/sftp/sftp.h b/src/sftp/sftp.h index 66ab9f80b1bd33ac6cb798b11103c9ba891e74f0..58a70932b746d72111c7ccc5545f4b17c9568509 100644 --- a/src/sftp/sftp.h +++ b/src/sftp/sftp.h @@ -65,8 +65,9 @@ typedef struct sftp_rx_pkt { uint32_t cur; uint32_t sz; uint32_t used; + uint32_t len; uint8_t type; - uint8_t *data; + uint8_t data[]; } *sftp_rx_pkt_t; typedef struct sftp_string { @@ -91,7 +92,7 @@ struct sftp_file_attributes { }; typedef struct sftp_client_state { - bool (*send_cb)(sftp_tx_pkt_t *pkt, void *cb_data); + bool (*send_cb)(uint8_t *buf, size_t len, void *cb_data); xpevent_t recv_event; sftp_rx_pkt_t rxp; sftp_tx_pkt_t txp; @@ -111,12 +112,14 @@ uint32_t sftp_get32(sftp_rx_pkt_t pkt); uint32_t sftp_get64(sftp_rx_pkt_t pkt); sftp_str_t sftp_getstring(sftp_rx_pkt_t pkt, uint8_t **str); bool sftp_rx_pkt_append(sftp_rx_pkt_t *pkt, uint8_t *inbuf, uint32_t len); +bool sftp_tx_pkt_reset(sftp_tx_pkt_t *pktp); bool sftp_appendbyte(sftp_tx_pkt_t *pktp, uint8_t u8); bool sftp_append32(sftp_tx_pkt_t *pktp, uint32_t u32); bool sftp_append64(sftp_tx_pkt_t *pktp, uint64_t u); bool sftp_appendstring(sftp_tx_pkt_t *pktp, sftp_str_t s); void sftp_free_tx_pkt(sftp_tx_pkt_t pkt); void sftp_free_rx_pkt(sftp_rx_pkt_t pkt); +bool sftp_prep_tx_packet(sftp_tx_pkt_t pkt, uint8_t **buf, size_t *sz); /* sftp_str.c */ sftp_str_t sftp_strdup(const char *str); @@ -126,7 +129,7 @@ void free_sftp_str(sftp_str_t str); /* sftp_client.c */ void sftpc_finish(sftpc_state_t state); -sftpc_state_t sftpc_begin(bool (*send_cb)(sftp_tx_pkt_t *pkt, void *cb_data), void *cb_data); +sftpc_state_t sftpc_begin(bool (*send_cb)(uint8_t *buf, size_t len, void *cb_data), void *cb_data); bool sftpc_init(sftpc_state_t state); bool sftpc_recv(sftpc_state_t state, uint8_t *buf, uint32_t sz); diff --git a/src/sftp/sftp_client.c b/src/sftp/sftp_client.c index f5be0f84c7a9a2d2636d684585626f964452643a..c17749b6cf2530668f706486270cf186ad1e8322 100644 --- a/src/sftp/sftp_client.c +++ b/src/sftp/sftp_client.c @@ -23,7 +23,7 @@ sftpc_finish(sftpc_state_t state) } sftpc_state_t -sftpc_begin(bool (*send_cb)(sftp_tx_pkt_t *pkt, void *cb_data), void *cb_data) +sftpc_begin(bool (*send_cb)(uint8_t *buf, size_t len, void *cb_data), void *cb_data) { sftpc_state_t ret = (sftpc_state_t)malloc(sizeof(sftpc_state_t)); if (ret == NULL) @@ -46,24 +46,34 @@ sftpc_init(sftpc_state_t state) { assert(state); if (!state) - return false; + goto fail; assert(state->thread == pthread_self()); if (state->thread != pthread_self()) - return false; + goto fail; if (!sftp_appendbyte(&state->txp, SSH_FXP_INIT)) - return false; + goto fail; if (!sftp_append32(&state->txp, SFTP_VERSION)) - return false; + goto fail; + uint8_t *txbuf; + size_t txsz; + if (!sftp_prep_tx_packet(state->txp, &txbuf, &txsz)) + goto fail; + if (!state->send_cb(txbuf, txsz, state->cb_data)) + goto fail; + sftp_tx_pkt_reset(&state->txp); if (WaitForEvent(state->recv_event, INFINITE) != WAIT_OBJECT_0) - return false; + goto fail; if (state->rxp->type != SSH_FXP_VERSION) - return false; + goto fail; if (sftp_get32(state->rxp) != SFTP_VERSION) - return false; + goto fail; sftp_remove_packet(state->rxp); if (!sftp_have_full_pkt(state->rxp)) ResetEvent(state->recv_event); return true; +fail: + sftp_tx_pkt_reset(&state->txp); + return false; } bool diff --git a/src/sftp/sftp_pkt.c b/src/sftp/sftp_pkt.c index 01c78696b326c2c68be910901462320dd2894a87..064ab91a0f68827f2e03e53e44b69a8c48b5e3c8 100644 --- a/src/sftp/sftp_pkt.c +++ b/src/sftp/sftp_pkt.c @@ -81,7 +81,7 @@ sftp_pkt_sz(sftp_rx_pkt_t pkt) if (!pkt) return false; assert(sftp_have_pkt_sz(pkt)); - return BE_INT32(pkt->used); + return BE_INT32(pkt->len); } uint8_t @@ -114,12 +114,12 @@ sftp_remove_packet(sftp_rx_pkt_t pkt) if (!pkt) return; uint32_t sz = sftp_pkt_sz(pkt); - assert(pkt->sz <= pkt->used); + assert(sz <= pkt->used); uint32_t newsz = pkt->used - sz - sizeof(uint32_t); - uint8_t *src = (uint8_t *)&pkt->sz; + uint8_t *src = (uint8_t *)&pkt->len; src += sizeof(uint32_t); - src += pkt->used; - memmove(&pkt->sz, src, newsz); + src += sz; + memmove(&pkt->len, src, newsz); pkt->used = newsz; // TODO: realloc() smaller? return; @@ -182,23 +182,22 @@ bool sftp_rx_pkt_append(sftp_rx_pkt_t *pktp, uint8_t *inbuf, uint32_t len) { assert(pktp); + if (!pktp) + return false; size_t old_sz; size_t new_sz; uint32_t old_used; - assert(pktp); - if (!pktp) - return false; sftp_rx_pkt_t pkt = *pktp; if (pkt == NULL) { old_sz = 0; old_used = 0; - new_sz = offsetof(struct sftp_rx_pkt, used) + len; + new_sz = offsetof(struct sftp_rx_pkt, len) + len; } else { old_used = pkt->used; old_sz = pkt->sz; - new_sz = offsetof(struct sftp_rx_pkt, used) + pkt->used + len; + new_sz = offsetof(struct sftp_rx_pkt, len) + pkt->used + len; } if (new_sz > old_sz) { if (new_sz % SFTP_MIN_PACKET_ALLOC) @@ -211,9 +210,10 @@ sftp_rx_pkt_append(sftp_rx_pkt_t *pktp, uint8_t *inbuf, uint32_t len) } *pktp = new_buf; pkt = *pktp; + pkt->sz = new_sz; } - memcpy(&((uint8_t *)&(pkt->used))[old_used], inbuf, len); - pkt->used += len; + memcpy(&((uint8_t *)&(pkt->len))[old_used], inbuf, len); + pkt->used = old_used + len; return true; } @@ -230,7 +230,6 @@ grow_tx(sftp_tx_pkt_t *pktp, uint32_t need) if (pktp == NULL) return false; sftp_tx_pkt_t pkt = *pktp; - assert(pkt->sz >= pkt->used); size_t newsz; uint32_t oldsz; uint32_t oldused; @@ -255,6 +254,26 @@ grow_tx(sftp_tx_pkt_t *pktp, uint32_t need) pkt->sz = newsz; pkt->used = oldused; } + assert(pkt->sz >= pkt->used); + return true; +} + +bool +sftp_tx_pkt_reset(sftp_tx_pkt_t *pktp) +{ + assert(pktp); + if (pktp == NULL) + return false; + sftp_tx_pkt_t pkt = *pktp; + pkt->used = 0; + if (pkt->sz == SFTP_MIN_PACKET_ALLOC) + return true; + void *newbuf = realloc(pkt, SFTP_MIN_PACKET_ALLOC); + if (newbuf != NULL) { + *pktp = newbuf; + pkt = *pktp; + pkt->sz = SFTP_MIN_PACKET_ALLOC; + } return true; } @@ -297,6 +316,20 @@ sftp_appendstring(sftp_tx_pkt_t *pktp, sftp_str_t s) return true; } +bool +sftp_prep_tx_packet(sftp_tx_pkt_t pkt, uint8_t **buf, size_t *sz) +{ + assert(pkt); + assert(buf); + assert(sz); + if (pkt == NULL || buf == NULL || sz == NULL) + return false; + *sz = pkt->used + sizeof(pkt->used); + pkt->used = BE_INT32(pkt->used); + *buf = (uint8_t *)&pkt->used; + return true; +} + void sftp_free_tx_pkt(sftp_tx_pkt_t pkt) {