add first-in-first-out queue for image prompts, and add pagination
This commit is contained in:
@@ -8,6 +8,8 @@ defmodule Diffuser.Generator.PromptRequest do
|
||||
schema "prompt_requests" do
|
||||
field :prompt, :string
|
||||
field :status, :string, default: "queued"
|
||||
field :steps, :integer
|
||||
field :guidance_scale, :float
|
||||
|
||||
has_many :images, PromptRequestResult, on_delete: :delete_all
|
||||
|
||||
@@ -17,7 +19,7 @@ defmodule Diffuser.Generator.PromptRequest do
|
||||
@doc false
|
||||
def changeset(prompt_request, attrs) do
|
||||
prompt_request
|
||||
|> cast(attrs, [:prompt, :status])
|
||||
|> cast(attrs, [:prompt, :status, :steps, :guidance_scale])
|
||||
|> validate_required([:prompt])
|
||||
end
|
||||
end
|
||||
|
@@ -1,98 +0,0 @@
|
||||
defmodule Diffuser.Generator.PromptRequestGenserver do
|
||||
use GenServer
|
||||
alias Diffuser.Generator
|
||||
alias Diffuser.Generator.PromptRequest
|
||||
alias DiffuserWeb.Endpoint
|
||||
alias Diffuser.PythonHelper, as: Helper
|
||||
|
||||
@path 'lib/diffuser/python'
|
||||
|
||||
def new(%{prompt_request: %PromptRequest{} = prompt_request}) do
|
||||
GenServer.start_link(
|
||||
__MODULE__,
|
||||
%{prompt_request: prompt_request},
|
||||
name: name_for(prompt_request)
|
||||
)
|
||||
end
|
||||
|
||||
def name_for(%PromptRequest{id: prompt_request_id}),
|
||||
do: {:global, "prompt_request:#{prompt_request_id}"}
|
||||
|
||||
def init(%{prompt_request: %PromptRequest{} = prompt_request}) do
|
||||
send(self(), :start_prompt)
|
||||
|
||||
{:ok,
|
||||
%{
|
||||
prompt_request: prompt_request
|
||||
}}
|
||||
end
|
||||
|
||||
def handle_info(:start_prompt, %{prompt_request: prompt_request} = state) do
|
||||
with {:ok, %{prompt: prompt} = active_prompt} <-
|
||||
update_and_broadcast_progress(prompt_request, "in_progress"),
|
||||
:ok <- call_python(:test_script, :test_func, prompt),
|
||||
%PromptRequest{} = prompt_request_with_results <- write_and_save_images(active_prompt),
|
||||
{:ok, completed_prompt} <-
|
||||
update_and_broadcast_progress(prompt_request_with_results, "finished") do
|
||||
IO.inspect(completed_prompt)
|
||||
{:noreply, state}
|
||||
else
|
||||
nil ->
|
||||
raise("prompt not found")
|
||||
|
||||
{:error, message} ->
|
||||
raise(message)
|
||||
end
|
||||
end
|
||||
|
||||
defp update_and_broadcast_progress(%PromptRequest{id: id} = prompt_request, new_status) do
|
||||
{:ok, new_prompt} = Generator.update_prompt_request(prompt_request, %{status: new_status})
|
||||
:ok = Endpoint.broadcast("request:#{id}", "request", %{prompt_request: new_prompt})
|
||||
|
||||
{:ok, new_prompt}
|
||||
end
|
||||
|
||||
defp call_python(_module, _func, prompt) do
|
||||
Port.open(
|
||||
{:spawn, "python #{@path}/stable_diffusion.py --prompt #{prompt}"},
|
||||
[:binary, {:packet, 4}]
|
||||
)
|
||||
|
||||
# TODO: We will want to flush, and get the image data from the script
|
||||
# then write it to PromptResult
|
||||
|
||||
# pid = Helper.py_instance(Path.absname(@path))
|
||||
# :python.call(pid, module, func, args)
|
||||
|
||||
# pid
|
||||
# |> :python.stop()
|
||||
|
||||
:ok
|
||||
end
|
||||
|
||||
defp write_and_save_images(%PromptRequest{id: id, prompt: prompt}) do
|
||||
height = :rand.uniform(512)
|
||||
width = :rand.uniform(512)
|
||||
IO.inspect(height)
|
||||
|
||||
{:ok, resp} =
|
||||
:httpc.request(
|
||||
:get,
|
||||
{'http://placekitten.com/#{height}/#{width}', []},
|
||||
[],
|
||||
body_format: :binary
|
||||
)
|
||||
|
||||
{{_, 200, 'OK'}, _headers, body} = resp
|
||||
|
||||
Generator.create_prompt_request_results(id, [
|
||||
%{
|
||||
file_name: "#{prompt}.jpg",
|
||||
filename: "#{prompt}.jpg",
|
||||
binary: body
|
||||
}
|
||||
])
|
||||
|
||||
Generator.get_prompt_request!(id)
|
||||
end
|
||||
end
|
50
lib/diffuser/generator/prompt_request_queue.ex
Normal file
50
lib/diffuser/generator/prompt_request_queue.ex
Normal file
@@ -0,0 +1,50 @@
|
||||
defmodule Diffuser.Generator.PromptRequestQueue do
|
||||
use GenServer
|
||||
alias Diffuser.Generator.PromptRequestWorker
|
||||
|
||||
### GenServer API
|
||||
|
||||
@doc """
|
||||
GenServer.init/1 callback
|
||||
"""
|
||||
def init(state) do
|
||||
{:ok, state}
|
||||
end
|
||||
|
||||
@doc """
|
||||
GenServer.handle_call/3 callback
|
||||
"""
|
||||
def handle_call(:dequeue, _from, [value | state]) do
|
||||
{:reply, value, state}
|
||||
end
|
||||
|
||||
def handle_call(:dequeue, _from, []), do: {:reply, nil, []}
|
||||
|
||||
def handle_call(:queue, _from, state), do: {:reply, state, state}
|
||||
|
||||
@doc """
|
||||
GenServer.handle_cast/2 callback
|
||||
"""
|
||||
def handle_cast({:enqueue, value}, state) do
|
||||
{:noreply, state, {:continue, {:enqueue, value}}}
|
||||
end
|
||||
|
||||
def handle_continue({:enqueue, value}, state) when length(state) > 0 do
|
||||
{:continue, {:enqueue, value}, state}
|
||||
end
|
||||
|
||||
def handle_continue({:enqueue, value}, state) do
|
||||
PromptRequestWorker.start(value)
|
||||
{:noreply, state}
|
||||
end
|
||||
|
||||
### Client API / Helper functions
|
||||
|
||||
def start_link(state \\ []) do
|
||||
GenServer.start_link(__MODULE__, state, name: __MODULE__)
|
||||
end
|
||||
|
||||
def queue, do: GenServer.call(__MODULE__, :queue)
|
||||
def enqueue(value), do: GenServer.cast(__MODULE__, {:enqueue, value})
|
||||
def dequeue, do: GenServer.call(__MODULE__, :dequeue)
|
||||
end
|
@@ -1,27 +0,0 @@
|
||||
defmodule Diffuser.Generator.PromptRequestSupervisor do
|
||||
use DynamicSupervisor
|
||||
alias Diffuser.Generator.PromptRequest
|
||||
|
||||
def start_link(init_arg) do
|
||||
DynamicSupervisor.start_link(__MODULE__, init_arg, name: __MODULE__)
|
||||
end
|
||||
|
||||
@impl true
|
||||
def init(_init_arg) do
|
||||
DynamicSupervisor.init(strategy: :one_for_one)
|
||||
end
|
||||
|
||||
def start_prompt_request(%PromptRequest{} = prompt_request) do
|
||||
Task.Supervisor.start_child(
|
||||
__MODULE__,
|
||||
Diffuser.Generator.PromptRequestGenserver,
|
||||
:new,
|
||||
[
|
||||
%{
|
||||
prompt_request: prompt_request
|
||||
}
|
||||
],
|
||||
restart: :transient
|
||||
)
|
||||
end
|
||||
end
|
85
lib/diffuser/generator/prompt_request_worker.ex
Normal file
85
lib/diffuser/generator/prompt_request_worker.ex
Normal file
@@ -0,0 +1,85 @@
|
||||
defmodule Diffuser.Generator.PromptRequestWorker do
|
||||
alias Diffuser.Generator
|
||||
alias Diffuser.Generator.PromptRequest
|
||||
alias DiffuserWeb.Endpoint
|
||||
alias Diffuser.Repo
|
||||
|
||||
@path [:code.priv_dir(:diffuser), "python"] |> Path.join()
|
||||
@steps 10
|
||||
@guidance_scale 7.5
|
||||
|
||||
def start(%PromptRequest{} = prompt_request) do
|
||||
with {:ok, active_prompt} <-
|
||||
update_and_broadcast_progress(prompt_request, "in_progress"),
|
||||
{:ok, _file_location} <- call_python(:test_script, :test_func, active_prompt),
|
||||
%PromptRequest{} = prompt_request_with_results <-
|
||||
write_and_save_images(active_prompt),
|
||||
{:ok, completed_prompt} <-
|
||||
update_and_broadcast_progress(prompt_request_with_results, "finished") do
|
||||
{:ok, completed_prompt |> Repo.preload(:images)}
|
||||
else
|
||||
nil ->
|
||||
raise("prompt not found")
|
||||
|
||||
{:error, message} ->
|
||||
raise(message)
|
||||
end
|
||||
end
|
||||
|
||||
defp update_and_broadcast_progress(%PromptRequest{id: id} = prompt_request, new_status) do
|
||||
{:ok, new_prompt} =
|
||||
Generator.update_prompt_request(prompt_request, %{
|
||||
status: new_status,
|
||||
steps: @steps,
|
||||
guidance_scale: @guidance_scale
|
||||
})
|
||||
|
||||
:ok = Endpoint.broadcast("request:#{id}", "request", %{prompt_request: new_prompt})
|
||||
|
||||
{:ok, new_prompt}
|
||||
end
|
||||
|
||||
defp call_python(_module, _func, %{id: prompt_id, prompt: prompt}) do
|
||||
port =
|
||||
Port.open(
|
||||
{:spawn,
|
||||
~s(python #{@path}/stable_diffusion.py --prompt "#{prompt}" --output "#{@path}/#{prompt_id}.png" --num-inference-steps #{@steps})},
|
||||
[:binary]
|
||||
)
|
||||
|
||||
python_loop(port, prompt_id)
|
||||
end
|
||||
|
||||
defp python_loop(port, prompt_id) do
|
||||
receive do
|
||||
{^port, {:data, ":finished" <> msg}} ->
|
||||
{:ok, msg}
|
||||
|
||||
{^port, {:data, ":step" <> step}} ->
|
||||
Endpoint.broadcast("request:#{prompt_id}", "progress", step)
|
||||
python_loop(port, prompt_id)
|
||||
|
||||
{^port, result} ->
|
||||
IO.inspect(result, label: "RESULT")
|
||||
python_loop(port, prompt_id)
|
||||
end
|
||||
end
|
||||
|
||||
defp write_and_save_images(%PromptRequest{id: id}) do
|
||||
file_path = "#{@path}/#{id}.png"
|
||||
|
||||
with {:ok, body} <- File.read(file_path),
|
||||
{:ok, _result} <-
|
||||
Generator.create_prompt_request_result(
|
||||
id,
|
||||
%{
|
||||
file_name: "#{id}.png",
|
||||
filename: "#{id}.png",
|
||||
binary: body
|
||||
}
|
||||
),
|
||||
:ok <- File.rm(file_path) do
|
||||
Generator.get_prompt_request!(id)
|
||||
end
|
||||
end
|
||||
end
|
Reference in New Issue
Block a user