Skip to content

Commit 55ff064

Browse files
committed
feat(multipart): allow reusing types other than stream iterators
1 parent e3ae1a5 commit 55ff064

File tree

3 files changed

+255
-73
lines changed

3 files changed

+255
-73
lines changed

src/client/body/multipart.rs

Lines changed: 164 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,61 @@ use crate::{client::body::PyStream, error::Error, header::HeaderMap};
1212

1313
/// A multipart form for a request.
1414
#[pyclass(subclass)]
15-
pub struct Multipart(pub Option<multipart::Form>);
15+
pub struct Multipart {
16+
pub form: Option<multipart::Form>,
17+
pub parts: Vec<Part>,
18+
}
19+
20+
/// The data for a part value of a multipart form.
21+
#[derive(FromPyObject)]
22+
pub enum Value {
23+
Text(PyBackedStr),
24+
Bytes(PyBackedBytes),
25+
File(PathBuf),
26+
Stream(PyStream),
27+
}
28+
29+
/// A part of a multipart form.
30+
#[pyclass(subclass)]
31+
pub struct Part {
32+
pub name: String,
33+
pub value: Option<Value>,
34+
pub filename: Option<String>,
35+
pub mime: Option<String>,
36+
pub length: Option<u64>,
37+
pub headers: Option<HeaderMap>,
38+
}
39+
40+
// ===== impl Multipart =====
1641

1742
#[pymethods]
1843
impl Multipart {
19-
/// Creates a new multipart form.
44+
/// Creates a new multipart.
2045
#[new]
2146
#[pyo3(signature = (*parts))]
22-
pub fn new(parts: &Bound<PyTuple>) -> PyResult<Multipart> {
23-
let mut form = multipart::Form::new();
47+
pub fn new(py: Python, parts: &Bound<PyTuple>) -> PyResult<Multipart> {
48+
let mut new_parts = Vec::with_capacity(parts.len());
2449
for part in parts {
2550
let part = part.cast::<Part>()?;
2651
let mut part = part.borrow_mut();
27-
form = part
28-
.name
29-
.take()
30-
.zip(part.inner.take())
31-
.map(|(name, inner)| form.part(name, inner))
32-
.ok_or_else(|| Error::Memory)?;
52+
new_parts.push(part.try_clone(py)?);
53+
}
54+
55+
Ok(Self {
56+
form: None,
57+
parts: new_parts,
58+
})
59+
}
60+
}
61+
62+
impl Multipart {
63+
fn build_form(&mut self, py: Python) -> PyResult<multipart::Form> {
64+
let mut form = multipart::Form::new();
65+
for part in &mut self.parts {
66+
let (name, inner) = part.build_form_part(py)?;
67+
form = form.part(name, inner);
3368
}
34-
Ok(Multipart(Some(form)))
69+
Ok(form)
3570
}
3671
}
3772

@@ -40,31 +75,120 @@ impl FromPyObject<'_, '_> for Multipart {
4075

4176
fn extract(ob: Borrowed<PyAny>) -> PyResult<Self> {
4277
let multipart = ob.cast::<Multipart>()?;
43-
multipart
44-
.borrow_mut()
45-
.0
46-
.take()
47-
.map(Some)
48-
.map(Self)
49-
.ok_or_else(|| Error::Memory)
50-
.map_err(Into::into)
78+
let mut multipart = multipart.borrow_mut();
79+
let form = multipart.build_form(ob.py())?;
80+
81+
Ok(Multipart {
82+
form: Some(form),
83+
parts: Vec::new(),
84+
})
5185
}
5286
}
5387

54-
/// A part of a multipart form.
55-
#[pyclass(subclass)]
56-
pub struct Part {
57-
pub name: Option<String>,
58-
pub inner: Option<multipart::Part>,
88+
// ===== impl Value =====
89+
90+
impl Value {
91+
fn try_clone(&self, py: Python) -> Option<Self> {
92+
match self {
93+
Value::Text(text) => {
94+
let text = text.clone_ref(py);
95+
Some(Value::Text(text))
96+
}
97+
Value::Bytes(bytes) => {
98+
let bytes = bytes.clone_ref(py);
99+
Some(Value::Bytes(bytes))
100+
}
101+
Value::File(path) => {
102+
let path = path.clone();
103+
Some(Value::File(path))
104+
}
105+
Value::Stream(_) => None,
106+
}
107+
}
59108
}
60109

61-
/// The data for a part value of a multipart form.
62-
#[derive(FromPyObject)]
63-
pub enum Value {
64-
Text(PyBackedStr),
65-
Bytes(PyBackedBytes),
66-
File(PathBuf),
67-
Stream(PyStream),
110+
// ===== impl Part =====
111+
112+
impl Part {
113+
fn with_value(&self, value: Value) -> Part {
114+
Part {
115+
name: self.name.clone(),
116+
value: Some(value),
117+
filename: self.filename.clone(),
118+
mime: self.mime.clone(),
119+
length: self.length,
120+
headers: self.headers.clone(),
121+
}
122+
}
123+
124+
fn build_inner(value: Value, length: Option<u64>) -> Result<multipart::Part, Error> {
125+
Ok(match value {
126+
Value::Text(text) => multipart::Part::stream(Body::from(Bytes::from_owner(text))),
127+
Value::Bytes(bytes) => multipart::Part::stream(Body::from(Bytes::from_owner(bytes))),
128+
Value::File(path) => pyo3_async_runtimes::tokio::get_runtime()
129+
.block_on(multipart::Part::file(path))
130+
.map_err(Error::from)?,
131+
Value::Stream(stream) => {
132+
let stream = Body::wrap_stream(stream);
133+
match length {
134+
Some(length) => multipart::Part::stream_with_length(stream, length),
135+
None => multipart::Part::stream(stream),
136+
}
137+
}
138+
})
139+
}
140+
141+
fn clone_value_or_take(&mut self, py: Python) -> PyResult<Value> {
142+
self.value
143+
.as_ref()
144+
.and_then(|value| value.try_clone(py))
145+
.or_else(|| self.value.take())
146+
.ok_or_else(|| Error::Memory.into())
147+
}
148+
149+
fn build_form_part(&mut self, py: Python) -> PyResult<(String, multipart::Part)> {
150+
let value = self.clone_value_or_take(py)?;
151+
let name = self.name.clone();
152+
let filename = self.filename.clone();
153+
let mime = self.mime.clone();
154+
let length = self.length;
155+
let headers = self.headers.clone();
156+
157+
py.detach(move || {
158+
let mut inner = Self::build_inner(value, length)?;
159+
160+
if let Some(filename) = filename {
161+
inner = inner.file_name(filename);
162+
}
163+
164+
if let Some(mime) = mime {
165+
inner = inner.mime_str(&mime).map_err(Error::Library)?;
166+
}
167+
168+
if let Some(headers) = headers {
169+
inner = inner.headers(headers.0);
170+
}
171+
172+
Ok((name, inner))
173+
})
174+
}
175+
176+
fn try_clone(&mut self, py: Python) -> PyResult<Part> {
177+
if let Some(part) = self
178+
.value
179+
.as_ref()
180+
.and_then(|value| value.try_clone(py))
181+
.map(|value| self.with_value(value))
182+
{
183+
return Ok(part);
184+
}
185+
186+
self.value
187+
.take()
188+
.map(|value| self.with_value(value))
189+
.ok_or_else(|| Error::Memory)
190+
.map_err(Into::into)
191+
}
68192
}
69193

70194
#[pymethods]
@@ -80,52 +204,20 @@ impl Part {
80204
headers = None
81205
))]
82206
pub fn new(
83-
py: Python,
84207
name: String,
85208
value: Value,
86209
filename: Option<String>,
87210
mime: Option<&str>,
88211
length: Option<u64>,
89212
headers: Option<HeaderMap>,
90-
) -> PyResult<Part> {
91-
py.detach(|| {
92-
// Create the inner part
93-
let mut inner = match value {
94-
Value::Text(text) => multipart::Part::stream(Body::from(Bytes::from_owner(text))),
95-
Value::Bytes(bytes) => {
96-
multipart::Part::stream(Body::from(Bytes::from_owner(bytes)))
97-
}
98-
Value::File(path) => pyo3_async_runtimes::tokio::get_runtime()
99-
.block_on(multipart::Part::file(path))
100-
.map_err(Error::from)?,
101-
Value::Stream(stream) => {
102-
let stream = Body::wrap_stream(stream);
103-
match length {
104-
Some(length) => multipart::Part::stream_with_length(stream, length),
105-
None => multipart::Part::stream(stream),
106-
}
107-
}
108-
};
109-
110-
// Set the filename and MIME type if provided
111-
if let Some(filename) = filename {
112-
inner = inner.file_name(filename);
113-
}
114-
115-
// Set the MIME type if provided
116-
if let Some(mime) = mime {
117-
inner = inner.mime_str(mime).map_err(Error::Library)?;
118-
}
119-
120-
// Set the headers if provided
121-
if let Some(headers) = headers {
122-
inner = inner.headers(headers.0);
123-
}
124-
125-
Ok(Part {
126-
name: Some(name),
127-
inner: Some(inner),
128-
})
129-
})
213+
) -> Part {
214+
Part {
215+
name,
216+
value: Some(value),
217+
filename,
218+
mime: mime.map(ToOwned::to_owned),
219+
length,
220+
headers,
221+
}
130222
}
131223
}

src/client/req.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ where
396396
apply_option!(
397397
set_if_some,
398398
builder,
399-
request.multipart.and_then(|form| form.0),
399+
request.multipart.and_then(|form| form.form),
400400
multipart
401401
);
402402
apply_option!(

tests/multipart_test.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
import wreq
5+
from wreq import Multipart, Part
6+
7+
client = wreq.Client(tls_info=True)
8+
9+
10+
def assert_form_value(data, key, expected):
11+
value = data["form"][key]
12+
if isinstance(value, list):
13+
assert expected in value
14+
else:
15+
assert value == expected
16+
17+
18+
@pytest.mark.asyncio
19+
@pytest.mark.flaky(reruns=3, reruns_delay=2)
20+
async def test_reuse_multipart_with_clonable_parts():
21+
form = Multipart(
22+
Part(name="a", value="1"),
23+
Part(name="b", value=b"2"),
24+
Part(name="c", value=Path("./README.md"), filename="README.md", mime="text/plain"),
25+
)
26+
27+
for _ in range(3):
28+
resp = await client.post("https://httpbin.io/post", multipart=form)
29+
async with resp:
30+
assert resp.status.is_success()
31+
data = await resp.json()
32+
assert_form_value(data, "a", "1")
33+
assert_form_value(data, "b", "2")
34+
assert "c" in data["files"]
35+
36+
37+
@pytest.mark.asyncio
38+
@pytest.mark.flaky(reruns=3, reruns_delay=2)
39+
async def test_stream_part_is_one_shot_when_reusing_multipart():
40+
def file_stream(path):
41+
with open(path, "rb") as f:
42+
while chunk := f.read(1024):
43+
yield chunk
44+
45+
form = Multipart(
46+
Part(
47+
name="stream",
48+
value=file_stream("./README.md"),
49+
filename="README.md",
50+
mime="text/plain",
51+
),
52+
)
53+
54+
resp = await client.post("https://httpbin.io/post", multipart=form)
55+
async with resp:
56+
assert resp.status.is_success()
57+
58+
with pytest.raises(RuntimeError):
59+
resp = await client.post("https://httpbin.io/post", multipart=form)
60+
async with resp:
61+
pass
62+
63+
64+
@pytest.mark.asyncio
65+
@pytest.mark.flaky(reruns=3, reruns_delay=2)
66+
async def test_reuse_same_part_without_copy_for_clonable_value():
67+
part = Part(name="a", value="1")
68+
69+
form1 = Multipart(part)
70+
form2 = Multipart(part)
71+
72+
for form in (form1, form2):
73+
resp = await client.post("https://httpbin.io/post", multipart=form)
74+
async with resp:
75+
assert resp.status.is_success()
76+
data = await resp.json()
77+
assert_form_value(data, "a", "1")
78+
79+
80+
@pytest.mark.asyncio
81+
@pytest.mark.flaky(reruns=3, reruns_delay=2)
82+
async def test_reuse_same_part_without_copy_fails_for_stream_value():
83+
def bytes_stream():
84+
yield b"hello"
85+
86+
part = Part(name="stream", value=bytes_stream())
87+
Multipart(part)
88+
89+
with pytest.raises(RuntimeError):
90+
Multipart(part)

0 commit comments

Comments
 (0)