8000 FEAT: support sd inpainting models by qinxuye · Pull Request #1879 · xorbitsai/inference · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

FEAT: support sd inpainting models #1879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/gen_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def get_unique_id(spec):
if not available_controlnet:
available_controlnet = None
model["available_controlnet"] = available_controlnet
model["model_ability"] = model.get("ability", "text-to-image")
rendered = env.get_template('image.rst.jinja').render(model)
output_file_path = os.path.join(output_dir, f"{model['model_name'].lower()}.rst")
with open(output_file_path, 'w') as output_file:
Expand Down
19 changes: 19 additions & 0 deletions doc/source/models/builtin/image/stable-diffusion-2-inpainting.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
.. _models_builtin_stable-diffusion-2-inpainting:

=============================
stable-diffusion-2-inpainting
=============================

- **Model Name:** stable-diffusion-2-inpainting
- **Model Family:** stable_diffusion
- **Abilities:** inpainting
- **Available ControlNet:** None

Specifications
^^^^^^^^^^^^^^

- **Model ID:** stabilityai/stable-diffusion-2-inpainting

Execute the following command to launch the model::

xinference launch --model-name stable-diffusion-2-inpainting --model-type image
19 changes: 19 additions & 0 deletions doc/source/models/builtin/image/stable-diffusion-inpainting.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
.. _models_builtin_stable-diffusion-inpainting:

===========================
stable-diffusion-inpainting
===========================

- **Model Name:** stable-diffusion-inpainting
- **Model Family:** stable_diffusion
- **Abilities:** inpainting
- **Available ControlNet:** None

Specifications
^^^^^^^^^^^^^^

- **Model ID:** runwayml/stable-diffusion-inpainting

Execute the following command to launch the model::

xinference launch --model-name stable-diffusion-inpainting --model-type image
2 changes: 1 addition & 1 deletion doc/templates/image.rst.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

- **Model Name:** {{ model_name }}
- **Model Family:** {{ model_family }}
- **Abilities:** text-to-image
- **Abilities:** {{ model_ability }}
- **Available ControlNet:** {{ available_controlnet }}

Specifications
Expand Down
65 changes: 65 additions & 0 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,17 @@ def serve(self, logging_conf: Optional[dict] = None):
else None
),
)
self._router.add_api_route(
"/v1/images/inpainting",
self.create_inpainting,
methods=["POST"],
response_model=ImageList,
dependencies=(
[Security(self._auth_service, scopes=["models:read"])]
if self.is_authenticated()
else None
),
)
self._router.add_api_route(
"/v1/chat/completions",
self.create_chat_completion,
Expand Down Expand Up @@ -1410,6 +1421,60 @@ async def create_variations(
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def create_inpainting(
    self,
    model: str = Form(...),
    image: UploadFile = File(media_type="application/octet-stream"),
    mask_image: UploadFile = File(media_type="application/octet-stream"),
    prompt: Optional[Union[str, List[str]]] = Form(None),
    negative_prompt: Optional[Union[str, List[str]]] = Form(None),
    n: Optional[int] = Form(1),
    response_format: Optional[str] = Form("url"),
    size: Optional[str] = Form(None),
    kwargs: Optional[str] = Form(None),
) -> Response:
    """Handle POST /v1/images/inpainting.

    Repaints the region of ``image`` selected by ``mask_image`` according to
    ``prompt`` by forwarding the request to the resolved model actor.
    Model-lookup failures map to 400 (bad uid) / 500; inference failures map
    to 400 (RuntimeError) / 500, and every failure is reported as an error
    event before the HTTPException is raised.
    """
    model_uid = model
    try:
        supervisor_ref = await self._get_supervisor_ref()
        model_ref = await supervisor_ref.get_model(model_uid)
    except ValueError as ve:
        logger.error(str(ve), exc_info=True)
        await self._report_error_event(model_uid, str(ve))
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error(e, exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

    try:
        # "kwargs" arrives as a JSON-encoded form field; absent means no extras.
        parsed_kwargs = json.loads(kwargs) if kwargs is not None else {}
        source_image = Image.open(image.file)
        source_mask = Image.open(mask_image.file)
        if not size:
            # No explicit size: fall back to the source image's own dimensions.
            width, height = source_image.size
            size = f"{width}*{height}"
        image_list = await model_ref.inpainting(
            image=source_image,
            mask_image=source_mask,
            prompt=prompt,
            negative_prompt=negative_prompt,
            n=n,
            size=size,
            response_format=response_format,
            **parsed_kwargs,
        )
        return Response(content=image_list, media_type="application/json")
    except RuntimeError as err:
        logger.error(err, exc_info=True)
        await self._report_error_event(model_uid, str(err))
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as e:
        logger.error(e, exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

async def create_flexible_infer(self, request: Request) -> Response:
payload = await request.json()

Expand Down
75 changes: 75 additions & 0 deletions xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,81 @@ def image_to_image(
response_data = response.json()
return response_data

def inpainting(
    self,
    image: Union[str, bytes],
    mask_image: Union[str, bytes],
    prompt: str,
    negative_prompt: Optional[str] = None,
    n: int = 1,
    size: Optional[str] = None,
    response_format: str = "url",
    **kwargs,
) -> "ImageList":
    """
    Repaint the masked region of an image, guided by a text prompt.

    Parameters
    ----------
    image: `Union[str, bytes]`
        The source image payload (raw bytes, or a string the server can
        decode into an image). The parts selected by ``mask_image`` are
        repainted according to ``prompt``; the rest is preserved.
    mask_image: `Union[str, bytes]`
        The mask applied to ``image``. White pixels mark the region to be
        repainted, black pixels mark the region to keep.
    prompt: `str` or `List[str]`
        The prompt or prompts to guide the repainted content.
    negative_prompt: `str` or `List[str]`, optional
        The prompt or prompts describing what must NOT appear in the
        repainted region.
    n: `int`, defaults to 1
        The number of images to generate per prompt. Must be between 1 and 10.
    size: `str`, optional, defaults to None
        The width*height in pixels of the generated image. When omitted, the
        server falls back to the dimensions of ``image``.
    response_format: `str`, defaults to `url`
        The format in which the generated images are returned. Must be one of
        url or b64_json.

    Returns
    -------
    ImageList
        A list of image objects.

    Raises
    ------
    RuntimeError
        If the server responds with a non-200 status code.
    """
    url = f"{self._base_url}/v1/images/inpainting"
    params = {
        "model": self._model_uid,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "n": n,
        "size": size,
        "response_format": response_format,
        "kwargs": json.dumps(kwargs),
    }
    # Everything goes over multipart/form-data so the (potentially large)
    # image payloads can ride alongside the scalar form fields.
    files: List[Any] = []
    for key, value in params.items():
        files.append((key, (None, value)))
    files.append(("image", ("image", image, "application/octet-stream")))
    files.append(
        ("mask_image", ("mask_image", mask_image, "application/octet-stream"))
    )
    response = requests.post(url, files=files, headers=self.auth_headers)
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to inpaint the images, detail: {_get_error_string(response)}"
        )

    response_data = response.json()
    return response_data


class RESTfulGenerateModelHandle(RESTfulModelHandle):
def generate(
Expand Down
29 changes: 29 additions & 0 deletions xinference/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,35 @@ async def image_to_image(
f"Model {self._model.model_spec} is not for creating image."
)

async def inpainting(
    self,
    image: "PIL.Image",
    mask_image: "PIL.Image",
    prompt: str,
    negative_prompt: str,
    n: int = 1,
    size: str = "1024*1024",
    response_format: str = "url",
    *args,
    **kwargs,
):
    """Forward an inpainting request to the wrapped image model.

    Delegates to the underlying model's ``inpainting`` method through
    ``self._call_wrapper``; raises AttributeError when the loaded model does
    not support inpainting (i.e. is not an image-generation model).
    """
    inpaint_fn = getattr(self._model, "inpainting", None)
    if inpaint_fn is None:
        # The wrapped model exposes no inpainting capability at all.
        raise AttributeError(
            f"Model {self._model.model_spec} is not for creating image."
        )
    return await self._call_wrapper(
        inpaint_fn,
        image,
        mask_image,
        prompt,
        negative_prompt,
        n,
        size,
        response_format,
        *args,
        **kwargs,
    )

@log_async(logger=logger)
@request_limit
async def infer(
Expand Down
3 changes: 3 additions & 0 deletions xinference/model/image/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
model_id: str
model_revision: str
model_hub: str = "huggingface"
ability: Optional[str]
controlnet: Optional[List["ImageModelFamilyV1"]]


Expand All @@ -71,6 +72,7 @@ def to_dict(self):
"model_name": self._model_spec.model_name,
"model_family": self._model_spec.model_family,
"model_revision": self._model_spec.model_revision,
"ability": self._model_spec.ability,
"controlnet": controlnet,
}

Expand Down Expand Up @@ -234,6 +236,7 @@ def create_image_model_instance(
lora_model_paths=lora_model,
lora_load_kwargs=lora_load_kwargs,
lora_fuse_kwargs=lora_fuse_kwargs,
ability=model_spec.ability,
**kwargs,
)
model_description = ImageModelDescription(
Expand Down
14 changes: 14 additions & 0 deletions xinference/model/image/model_spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,19 @@
"model_revision": "62134b9d8e703b5d6f74f1534457287a8bba77ef"
}
]
},
{
"model_name": "stable-diffusion-inpainting",
"model_family": "stable_diffusion",
"model_id": "runwayml/stable-diffusion-inpainting",
"model_revision": "51388a731f57604945fddd703ecb5c50e8e7b49d",
"ability": "inpainting"
},
{
"model_name": "stable-diffusion-2-inpainting",
"model_family": "stable_diffusion",
"model_id": "stabilityai/stable-diffusion-2-inpainting",
"model_revision": "81a84f49b15956b60b4272a405ad3daef3da4590",
"ability": "inpainting"
}
]
49 changes: 43 additions & 6 deletions xinference/model/image/stable_diffusion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging
import os
import re
import sys
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -39,6 +40,7 @@ def __init__(
lora_model: Optional[List[LoRA]] = None,
lora_load_kwargs: Optional[Dict] = None,
lora_fuse_kwargs: Optional[Dict] = None,
ability: Optional[str] = None,
**kwargs,
):
self._model_uid = model_uid
Expand All @@ -48,6 +50,7 @@ def __init__(
self._lora_model = lora_model
self._lora_load_kwargs = lora_load_kwargs or {}
self._lora_fuse_kwargs = lora_fuse_kwargs or {}
self._ability = ability
self._kwargs = kwargs

def _apply_lora(self):
Expand All @@ -64,8 +67,14 @@ def _apply_lora(self):
logger.info(f"Successfully loaded the LoRA for model {self._model_uid}.")

def load(self):
# import torch
from diffusers import AutoPipelineForText2Image
import torch

if self._ability in [None, "text2image", "image2image"]:
from diffusers import AutoPipelineForText2Image as AutoPipelineModel
elif self._ability == "inpainting":
from diffusers import AutoPipelineForInpainting as AutoPipelineModel
else:
raise ValueError(f"Unknown ability: {self._ability}")

controlnet = self._kwargs.get("controlnet")
if controlnet is not None:
Expand All @@ -74,12 +83,16 @@ def load(self):
logger.debug("Loading controlnet %s", controlnet)
self._kwargs["controlnet"] = ControlNetModel.from_pretrained(controlnet)

self._model = AutoPipelineForText2Image.from_pretrained(
torch_dtype = self._kwargs.get("torch_dtype")
if sys.platform != "darwin" and torch_dtype is None:
# The following params crashes on Mac M2
self._kwargs["torch_dtype"] = torch.float16
self._kwargs["use_safetensors"] = True

logger.debug("Loading model %s", AutoPipelineModel)
self._model = AutoPipelineModel.from_pretrained(
self._model_path,
**self._kwargs,
# The following params crashes on Mac M2
# torch_dtype=torch.float16,
# use_safetensors=True,
)
self._model = move_model_to_available_device(self._model)
# Recommended if your computer has < 64 GB of RAM
Expand Down Expand Up @@ -174,3 +187,27 @@ def image_to_image(
response_format=response_format,
**kwargs,
)

def inpainting(
    self,
    image: bytes,
    mask_image: bytes,
    prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    n: int = 1,
    size: str = "1024*1024",
    response_format: str = "url",
    **kwargs,
):
    """Repaint the region of ``image`` selected by ``mask_image``.

    ``size`` is a "<width><sep><height>" string where the separator is any
    run of non-digit characters (e.g. "1024*1024" or "512x512"); it is
    parsed into the ``width``/``height`` pipeline arguments. The actual
    generation is delegated to ``self._call_model``.

    Raises ValueError when ``size`` does not contain exactly two numbers.
    """
    # Filter out empty fragments so stray leading/trailing separators
    # (e.g. " 512x512 ") don't blow up with an opaque unpack error.
    dimensions = [int(part) for part in re.split(r"\D+", size) if part]
    if len(dimensions) != 2:
        raise ValueError(
            f"Invalid size: {size!r}, expected '<width>*<height>'"
        )
    width, height = dimensions
    return self._call_model(
        image=image,
        mask_image=mask_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_images_per_prompt=n,
        response_format=response_format,
        **kwargs,
    )
Loading
Loading
0