# ceshi2 / demotool.py
# Uploaded by zxcgqq ("Upload 8 files", commit b5e593e).
from langchain.tools import BaseTool, StructuredTool, Tool, tool
from typing import Optional, Type
from langchain.callbacks.manager import AsyncCallbackManagerForToolRun, CallbackManagerForToolRun
import requests
import base64
import os
import uuid
from PIL import Image, ImageOps, ImageDraw, ImageFont
def optimizationProblem(query):
    """Return *query* with a follow-up question about today's date appended.

    The fixed suffix " What's the date today?" is concatenated directly onto
    the input (note the leading space), so the caller's text is preserved
    unchanged at the front.
    """
    date_question = " What's the date today?"
    augmented = query + date_question
    return augmented
class CustomWeatherTool(BaseTool):
    """Stub LangChain tool that answers weather questions for a city.

    Synchronous calls return a fixed sentence built from the input; async
    execution is not implemented.
    """

    # Field overrides carry type annotations: pydantic-v2-based LangChain
    # releases reject non-annotated overrides of BaseTool fields, and the
    # annotations are harmless on older releases.
    name: str = "weather"
    description: str = "useful for when the input to this tool should be city"

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
        """Return a placeholder weather string for the city in *query*.

        Args:
            query: City name supplied by the agent.
            run_manager: LangChain callback manager (unused).
        """
        return "The weather in " + query

    async def _arun(self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None) -> str:
        """Async execution is not supported; always raises NotImplementedError."""
        # Report this tool's own name (the previous message was copy-pasted
        # from an unrelated "custom_search" tool).
        raise NotImplementedError(f"{self.name} does not support async")
class Text2Image(BaseTool):
    """LangChain tool that renders an image from a text prompt via an HTTP txt2img API.

    The prompt is POSTed to ``api_url``, presumably a Stable Diffusion WebUI
    ``/sdapi/v1/txt2img`` endpoint (TODO confirm the server type). The first
    returned base64 image is decoded and written as a PNG under ``output_dir``,
    and that file path is returned. ``return_direct`` is True, so the agent
    hands the tool output straight back to the user.
    """

    # Field overrides carry type annotations: pydantic-v2-based LangChain
    # releases reject non-annotated overrides of BaseTool fields.
    name: str = "Generate Image From User Input Text"
    description: str = "useful when you want to generate an image from a user input text and save it to a file. like: generate an image of an object or something, or generate an image that includes some objects. The input to this tool should be a string, representing the text used to generate image. "
    return_direct: bool = True
    # Image-generation endpoint; override per instance if the server moves.
    api_url: str = "http://region-9.seetacloud.com:39487/sdapi/v1/txt2img"
    # Directory the generated PNGs are written into (created on demand).
    output_dir: str = "image"
    # Seconds to wait for the server; generation can take a while, but an
    # unbounded wait would hang the agent forever.
    timeout: float = 300.0

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
        """Generate an image for *query* and return the path of the saved PNG.

        Args:
            query: Text prompt forwarded to the txt2img endpoint.
            run_manager: LangChain callback manager (unused).

        Returns:
            The written file path on success. Otherwise, a short
            human-readable error message, so the agent always receives a
            string instead of an exception.
        """
        body = {
            "negative_prompt": "",
            "width": 900,
            "prompt": query,
            "steps": 30,
            "cfg_scale": 8,
            "height": 900,
        }
        try:
            response = requests.post(self.api_url, json=body, timeout=self.timeout)
            response.raise_for_status()  # turn HTTP error statuses into exceptions
            images = response.json()["images"]
        except (requests.exceptions.RequestException, ValueError, KeyError, TypeError) as e:
            # ValueError: body is not JSON; KeyError/TypeError: unexpected JSON shape.
            print("An error occurred:", e)
            return f"An error occurred: {e}"

        if not images:
            return "No image was generated."

        # Accept both raw base64 and "data:image/png;base64,<payload>" strings:
        # the payload is whatever follows the first comma, or the whole string.
        image_b64 = images[0].split(",", 1)[-1]
        try:
            image_bytes = base64.b64decode(image_b64)
        except ValueError as e:  # binascii.Error is a ValueError subclass
            print("An error occurred:", e)
            return f"An error occurred: {e}"

        image_filename = os.path.join(self.output_dir, f"{str(uuid.uuid4())[:8]}.png")
        try:
            os.makedirs(self.output_dir, exist_ok=True)
            with open(image_filename, "wb") as file:
                file.write(image_bytes)
        except OSError as e:
            print("An error occurred:", e)
            return f"An error occurred: {e}"
        return image_filename

    async def _arun(self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None) -> str:
        """Async execution is not supported; always raises NotImplementedError."""
        # Report this tool's own name (the previous message was copy-pasted
        # from an unrelated "custom_search" tool).
        raise NotImplementedError(f"{self.name} does not support async")
@tool("search", return_direct=True)
def search_api(query: str) -> str:
    """Searches the API for the query."""
    # NOTE: @tool uses the docstring above as the tool description shown to
    # the agent, so it is runtime text and must stay exactly as written.
    # Placeholder implementation: no real search happens; the query is
    # echoed back inside a fixed template.
    template = "Results for query {}"
    return template.format(query)