-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
35 lines (31 loc) · 1.21 KB
/
main.py
File metadata and controls
35 lines (31 loc) · 1.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from generation_data import GenerationData
from generation_result import GenerationResult
from dotenv import load_dotenv
from replicate_api import init_replicate, text_to_image, generate_text, generate_image_style
import os
# Pull settings from a .env file (when one exists) into the process environment
# before any of the lookups below run.
load_dotenv()

# Credential passed to the Replicate client factory; None when the variable is
# not set.
REPLICATE_KEY = os.environ.get("REPLICATE_KEY")
replicate_client = init_replicate(REPLICATE_KEY)

# Port used by the development server in the __main__ guard. Fall back to 8000
# when PORT is missing or empty: the previous int(os.environ.get("PORT")) raised
# TypeError at import time, which also broke `uvicorn main:app` style launches
# that never read this value.
port = int(os.environ.get("PORT") or "8000")

app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive CORS policy; the FastAPI docs advise listing explicit origins when
# credentials are allowed. Left unchanged here to avoid breaking existing
# clients - confirm whether credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/generate")
async def generate(data: GenerationData) -> GenerationResult:
    """Handle POST /generate: refine the prompt, pick a style, render an image.

    The incoming prompt is passed to ``generate_text``; the result is passed to
    ``generate_image_style``; the two strings are joined with ", " and written
    back to ``data.prompt`` before ``text_to_image`` is called with ``data``.

    Args:
        data: Request payload; its ``prompt`` attribute is overwritten with the
            final prompt.

    Returns:
        A ``GenerationResult`` holding the generated image and the final prompt.
    """
    # NOTE(review): the three helper calls are not awaited, so they presumably
    # block the event loop while running - confirm whether that is acceptable.
    refined_prompt = generate_text(client=replicate_client, image_prompt=data.prompt)
    style_suffix = generate_image_style(
        client=replicate_client, image_prompt=refined_prompt
    )

    # text_to_image receives the whole request object, so the final prompt is
    # stored on it rather than passed separately.
    data.prompt = refined_prompt + ", " + style_suffix

    image = text_to_image(client=replicate_client, data=data)
    return GenerationResult(image=image, prompt=data.prompt)
if __name__ == "__main__":
    # Development entry point: uvicorn is imported only when this file is run
    # directly as a script.
    import uvicorn

    # Listen on all interfaces at the module-level `port`, with auto-reload on.
    server_options = {
        "host": "0.0.0.0",
        "port": port,
        "reload": True,
    }
    uvicorn.run("main:app", **server_options)