You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
64 lines
2.3 KiB
64 lines
2.3 KiB
3 weeks ago
|
import random
|
||
|
import requests
|
||
|
import io
|
||
|
import base64
|
||
|
from PIL import Image
|
||
|
|
||
|
def generate_cat_images(ip, port, model_params, prompt="a cat", num_images=5, timeout=30):
    """Generate images through a Stable Diffusion WebUI ``/sdapi/v1/txt2img`` API.

    Sends one request per iteration, each with a fresh random seed.  Every
    base64 image in the response is decoded and saved as a PNG in the current
    working directory.

    Args:
        ip: Host name or IP address of the API server.
        port: TCP port of the API server.
        model_params: Mapping with the keys ``"model"`` (checkpoint name),
            ``"steps"``, ``"cfg_scale"``, ``"width"``, ``"height"`` and
            ``"sampler"``.
        prompt: Positive text prompt sent with every request.
        num_images: Number of requests (and therefore seeds) to issue.
        timeout: Per-request timeout in seconds.  Slow hardware or large
            resolutions may need more than the default.

    Returns:
        List of filenames that were written, in the order they were saved.

    Raises:
        KeyError: If ``num_images > 0`` and ``model_params`` lacks a required key.

    Request, HTTP, JSON and image-decoding failures are reported with
    ``print``.  The loop then carries on with the next request or image
    (best-effort behaviour).
    """
    url = f"http://{ip}:{port}/sdapi/v1/txt2img"
    saved = []

    # One session reuses the TCP connection across all requests.
    with requests.Session() as session:
        for i in range(num_images):
            # Fresh 32-bit seed per request so every iteration differs.
            seed = random.randint(0, 2**32 - 1)

            payload = {
                "prompt": prompt,
                "negative_prompt": "",  # Modify if needed
                # The txt2img endpoint applies a checkpoint only through
                # override_settings; a top-level "sd_model_checkpoint" is ignored.
                "override_settings": {"sd_model_checkpoint": model_params["model"]},
                # Keep the checkpoint loaded between requests instead of
                # switching back to the previous model after every image.
                "override_settings_restore_afterwards": False,
                "steps": model_params["steps"],
                "cfg_scale": model_params["cfg_scale"],
                "width": model_params["width"],
                "height": model_params["height"],
                # sampler_name is the current field; sampler_index is still
                # sent for older servers that only understand it.
                "sampler_name": model_params["sampler"],
                "sampler_index": model_params["sampler"],
                "seed": seed,
            }

            try:
                response = session.post(url, json=payload, timeout=timeout)
                response.raise_for_status()  # Check for HTTP errors
                result = response.json()
            except requests.exceptions.Timeout:
                print(f"Timeout occurred while generating image {i}")
                continue
            except (requests.exceptions.RequestException, ValueError) as e:
                # ValueError covers a response body that is not valid JSON.
                print(f"An error occurred on iteration {i}: {e}")
                continue

            # The API returns a list of base64 strings in result["images"].
            images = result.get("images") if isinstance(result, dict) else None
            if not images:
                print(f"No images returned on iteration {i}")
                continue

            for j, img_data in enumerate(images):
                # The first image keeps the plain name; any extra images in
                # the same response get a suffix instead of overwriting it.
                filename = f"cat_{i}.png" if j == 0 else f"cat_{i}_{j}.png"
                try:
                    # Remove any header if present (e.g., "data:image/png;base64,")
                    img_base64 = img_data.split(",", 1)[-1]
                    with Image.open(io.BytesIO(base64.b64decode(img_base64))) as image:
                        image.save(filename)
                except Exception as e:
                    # Best-effort: one bad image must not drop the rest.
                    print(f"An error occurred on iteration {i}, image {j}: {e}")
                    continue
                saved.append(filename)
                print(f"Saved {filename}")

    return saved
|
||
|
|
||
|
if __name__ == "__main__":
    # Address of the txt2img server (custom port for the API).
    host, api_port = "172.30.200.3", 35000

    # Generation settings; set "model" to a checkpoint name your server knows.
    params = dict(
        model="sd3.5_checkpoint.ckpt",  # Replace with your actual model checkpoint file
        steps=20,
        cfg_scale=7.0,
        width=512,
        height=512,
        sampler="Euler",  # Adjust based on your setup
    )

    generate_cat_images(host, api_port, params, prompt="a cat", num_images=5)
|