Spaces:
Sleeping
Sleeping
from huggingface_hub import InferenceClient | |
import gradio as gr | |
import random | |
# Base URL of the hosted Inference API model endpoints.
# NOTE(review): nothing in the visible code reads API_URL; the client below
# is given the model id directly.
API_URL = "https://api-inference.huggingface.co/models/"

# Shared inference client used by generate() for every chat request.
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
def format_prompt(message, history):
    """Build the full text prompt sent to the model for one chat turn.

    The prompt is the fixed "BDFD Script Coder" preamble (the bot's name
    followed by the list of BDFD functions), then every past exchange as
    ``[INST] user [/INST] bot</s> ``, then the new message as
    ``[INST] message [/INST]``.

    Args:
        message: The new user message to append as the final instruction.
        history: Iterable of ``(user_prompt, bot_response)`` pairs for the
            earlier turns, oldest first. ``None`` or an empty iterable means
            there are no earlier turns.

    Returns:
        The assembled prompt string.
    """
    # NOTE(review): the preamble is emitted before (and directly adjacent to)
    # the first "[INST]" tag, with no "<s>" marker. The Mistral instruct
    # template wraps instructions inside [INST] ... [/INST]; confirm whether
    # the preamble should live inside the first instruction block. The layout
    # is left as-is here so model behaviour does not change.
    system_prompt = """Your name is BDFD Script Coder. Here are all functions: $addButton[ ]
$addCmdReactions[ ]
$addEmoji[ ]
$addField[ ]
$addReactions[ ]
$addSelectMenuOption[ ]
$addTextInput[ ]
$addTimestamp
$addTimestamp[ ]
$allMembersCount
$allowMention
$allowRoleMentions[ ]
$allowUserMentions[ ]
$alternativeParsing
$and[ ]
$appendOptionSuggestion[ ]
$argCount[ ]
$argsCheck[ ]
$async[ ]
$authorAvatar
$authorID
$authorIcon[ ]
$authorOfMessage[ ]
$authorURL[ ]
$author[ ]
$autoCompleteOptionName
$autoCompleteOptionValue
$awaitFunc[ ]
$awaitReactions[ ]
$await[ ]
$ban
$banID
$banID[ ]
$ban[ ]
$blackListIDs[ ]
$blackListRolesIDs[ ]
$blackListRoles[ ]
$blackListServers[ ]
$blackListUsers[ ]
$boostCount
$c[ ]
$calculate[ ]
$catch
$categoryCount
$categoryCount[ ]
$categoryID[ ]
$changeCooldownTime[ ]
$changeUsernameWithID[ ]
$changeUsername[ ]
$channelCount
$channelExists[ ]
$channelID
$channelIDFromName[ ]
$channelID[ ]
$channelName[ ]
$channelPosition
$channelPosition[ ]
$channelSendMessage[ ]
$channelTopic
$channelTopic[ ]
$channelType[ ]
$charCount[ ]
$checkCondition[ ]
$checkContains[ ]
$checkUserPerms[ ]
$clear
$clearReactions[ ]
$clear[ ]
$closeTicket[ ]
$colorRole[ ]
$color[ ]
$commandsCount
$cooldown[ ]
$createChannel[ ]
$createRole[ ]
$creationDate[ ]
$cropText[ ]
$customEmoji[ ]
$customID
$customImage[ ]
$date
$day
$defer
$deleteChannelsByName[ ]
$deleteChannels[ ]
$deleteIn[ ]
$deleteMessage[ ]
$deleteRole[ ]
$deletecommand
$description[ ]
$disableInnerSpaceRemoval
$disableSpecialEscaping
$discriminator[ ]
$divide[ ]
$dm
$dmChannelID[ ]
$dm[ ]
$editButton[ ]
$editChannelPerms[ ]
$editEmbedIn[ ]
$editIn[ ]
$editMessage[ ]
$editSelectMenuOption[ ]
$editSelectMenu[ ]
$editSplitText[ ]
$editThread[ ]
$else
$elseif[ ]
$embedSuppressErrors[ ]
$embeddedURL[ ]
$emoteCount
$enableDecimals[ ]
$enabled[ ]
$endasync
$endif
$endtry
$ephemeral
$error[ ]
$eval[ ]
$executionTime
$findChannel[ ]
$findRole[ ]
$findUser[ ]
$footerIcon[ ]
$footer[ ]
$getBanReason[ ]
$getBotInvite
$getChannelVar[ ]
$getCooldown[ ]
$getCustomStatus[ ]
$getEmbedData[ ]
$getInviteInfo[ ]
$getLeaderboardValue[ ]
$getMessage[ ]
$getReactions[ ]
$getRoleColor[ ]
$getServerInvite
$getServerInvite[ ]
$getServerVar[ ]
$getTextSplitIndex[ ]
$getTextSplitLength
$getTimestamp
$getTimestamp[ ]
$getUserStatus[ ]
$getUserVar[ ]
$getVar[ ]
$giveRole[ ]
$globalCooldown[ ]
$globalUserLeaderboard[ ]
$guildExists[ ]
$guildID
$guildID[ ]
$hasRole[ ]
$highestRole
$highestRoleWithPerms[ ]
$highestRole[ ]
$hostingExpireTime
$hostingExpireTime[ ]
$hour
$httpAddHeader[ ]
$httpDelete[ ]
$httpGetHeader[ ]
$httpGet[ ]
$httpPatch[ ]
$httpPost[ ]
$httpPut[ ]
$httpRemoveHeader[ ]
$httpResult
$httpResult[ ]
$httpStatus
$hypesquad[ ]
$if[ ]
$ignoreChannels[ ]
$ignoreLinks
$ignoreTriggerCase
$image[ ]
$input[ ]
$isAdmin[ ]
$isBanned[ ]
$isBoolean[ ]
$isBot[ ]
$isHoisted[ ]
$isMentionable[ ]
$isNSFW[ ]
$isNumber[ ]
$isSlash
$isTimedOut[ ]
$isUserDMEnabled[ ]
$isValidHex[ ]
$joinSplitText[ ]
$jsonArrayAppend[ ]
$jsonArrayCount[ ]
$jsonArray[ ]
$jsonClear
$jsonExists[ ]
$jsonParse[ ]
$jsonPretty[ ]
$jsonSet[ ]
$jsonStringify
$jsonUnset[ ]
$json[ ]
$kick
$kickMention
$kickMention[ ]
$kick[ ]
$lowestRole
$lowestRoleWithPerms[ ]
$lowestRole[ ]
$max[ ]
$membersCount
$membersCount[ ]
$mentionedChannels[ ]
$mentionedRoles[ ]
$mentioned[ ]
$message
$messageID
$message[ ]
$min[ ]
$minute
$modifyChannelPerms[ ]
$modifyChannel[ ]
$modifyRolePerms[ ]
$modifyRole[ ]
$modulo[ ]
$month
$multi[ ]
$mute[ ]
$newModal[ ]
$newSelectMenu[ ]
$newTicket[ ]
$nickname
$nickname[ ]
$noMentionMessage
$noMentionMessage[ ]
$nomention
$numberSeparator[ ]
$onlyAdmin[ ]
$onlyBotChannelPerms[ ]
$onlyBotPerms[ ]
$onlyForCategories[ ]
$onlyForChannels[ ]
$onlyForIDs[ ]
$onlyForRoleIDs[ ]
$onlyForRoles[ ]
$onlyForServers[ ]
$onlyForUsers[ ]
$onlyIfMessageContains[ ]
$onlyIf[ ]
$onlyNSFW[ ]
$onlyPerms[ ]
$optOff[ ]
$or[ ]
$parentID
$parentID[ ]
$pinMessage
$pinMessage[ ]
$ping
$premiumExpireTime
$premiumExpireTime[ ]
$publishMessage[ ]
$random
$randomChannelID
$randomMention
$randomString[ ]
$randomText[ ]
$randomUser
$randomUserID
$random[ ]
$registerGuildCommands
$registerGuildCommands[ ]
$removeButtons
$removeButtons[ ]
$removeComponent[ ]
$removeContains[ ]
$removeLinks
$removeLinks[ ]
$removeSplitTextElement[ ]
$repeatMessage[ ]
$replaceText[ ]
$repliedMessageID
$repliedMessageID[ ]
$reply
$replyIn[ ]
$reply[ ]
$resetChannelVar[ ]
$resetServerVar[ ]
$resetUserVar[ ]
$roleCount
$roleExists[ ]
$roleGrant[ ]
$roleID[ ]
$roleInfo[ ]
$roleName[ ]
$roleNames
$rolePosition[ ]
$round[ ]
$scriptLanguage
$second
$sendEmbedMessage[ ]
$sendMessage[ ]
$sendNotification[ ]
$serverChannelExists[ ]
$serverCooldown[ ]
$serverCount
$serverDescription
$serverDescription[ ]
$serverEmojis[ ]
$serverIcon
$serverIcon[ ]
$serverInfo[ ]
$serverLeaderboard[ ]
$serverName[ ]
$serverNames
$serverNames[ ]
$serverOwner
$serverOwner[ ]
$serverRegion
$serverVerificationLvl
$setChannelVar[ ]
$setServerVar[ ]
$setUserVar[ ]
$setVar[ ]
$shardID
$shardID[ ]
$slashCommandsCount
$slashID
$slashID[ ]
$slowmode[ ]
$sort[ ]
$splitText[ ]
$startThread[ ]
$stop
$sub[ ]
$sum[ ]
$suppressErrors
$suppressErrors[ ]
$takeRole[ ]
$textSplit[ ]
$threadAddMember[ ]
$threadRemoveMember[ ]
$thumbnail[ ]
$time[ ]
$timeout[ ]
$title[ ]
$toLowercase[ ]
$toTitleCase[ ]
$toUppercase[ ]
$trimContent
$trimSpace[ ]
$try
$tts
$unban
$unbanID
$unbanID[ ]
$unescape[ ]
$unmute[ ]
$unpinMessage[ ]
$unregisterGuildCommands
$unregisterGuildCommands[ ]
$untimeout[ ]
$uptime
$url[ ]
$useChannel[ ]
$usedEmoji
$userAvatar[ ]
$userExists[ ]
$userID[ ]
$userInfo[ ]
$userJoinedDiscord[ ]
$userJoined[ ]
$userLeaderboard[ ]
$userPerms[ ]
$userReacted[ ]
$userRoles[ ]
$userServerAvatar[ ]
$username
$username[ ]
$varExistError[ ]
$varExists[ ]
$var[ ]
$variablesCount[ ]
$webhookAvatarURL[ ]
$webhookColor[ ]
$webhookContent[ ]
$webhookCreate[ ]
$webhookDelete[ ]
$webhookDescription[ ]
$webhookFooter[ ]
$webhookSend[ ]
$webhookTitle[ ]
$webhookUsername[ ]"""

    # Collect the pieces and join once instead of growing a string with +=.
    pieces = [system_prompt]
    for user_prompt, bot_response in history or ():
        pieces.append(f"[INST] {user_prompt} [/INST]")
        pieces.append(f" {bot_response}</s> ")
    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)
def generate(prompt, history, temperature=0.9, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0):
    """Stream a model reply for *prompt*, yielding the text accumulated so far.

    The prompt and history are turned into a single string by
    ``format_prompt`` and sent through the module-level ``client`` with
    sampling enabled and a fresh random seed on every call.

    Args:
        prompt: The new user message.
        history: Earlier ``(user, bot)`` exchanges, forwarded unchanged to
            ``format_prompt``.
        temperature: Sampling temperature; coerced to ``float`` and raised to
            at least ``1e-2``.
        max_new_tokens: Upper bound on generated tokens; coerced to ``int``.
        top_p: Nucleus-sampling threshold; coerced to ``float``.
        repetition_penalty: Repetition penalty; coerced to ``float``.

    Yields:
        The reply text received so far, growing by one token per step.

    Returns:
        The complete reply (only visible as the generator's
        ``StopIteration`` value).
    """
    # Coerce every numeric setting the same way: the values may arrive from
    # UI widgets as either int or float. The temperature floor keeps a slider
    # value of 0.0 from being sent as-is (presumably the backend rejects
    # non-positive temperatures when sampling — confirm).
    generate_kwargs = dict(
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=int(max_new_tokens),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=random.randint(0, 10**7),
    )

    formatted_prompt = format_prompt(prompt, history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    output = ""
    for response in stream:
        output += response.token.text
        # Yield the running text so the UI can render the reply as it grows.
        yield output
    return output
# Extra controls shown with the chat UI. Their order matches the optional
# parameters of generate(): temperature, max_new_tokens, top_p,
# repetition_penalty.
additional_inputs = [
    gr.Slider(interactive=True, **slider_options)
    for slider_options in (
        dict(
            label="Temperature",
            value=0.9,
            minimum=0.0,
            maximum=1.0,
            step=0.05,
            info="Higher values produce more diverse outputs",
        ),
        dict(
            label="Max new tokens",
            value=2048,
            minimum=64,
            maximum=4096,
            step=64,
            info="The maximum numbers of new tokens",
        ),
        dict(
            label="Top-p (nucleus sampling)",
            value=0.90,
            minimum=0.0,
            maximum=1,
            step=0.05,
            info="Higher values sample more low-probability tokens",
        ),
        dict(
            label="Repetition penalty",
            value=1.2,
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            info="Penalize repeated tokens",
        ),
    )
]
# Stylesheet intended to enlarge the chat component.
# CSS only recognises /* ... */ comments; a "#" does not start a comment and
# would corrupt the rule, so the notes inside the string use /* ... */.
# NOTE(review): the visible code never passes this string to gr.Blocks(css=...),
# so it currently has no effect — confirm whether it should be wired in.
customCSS = """
#component-7 { /* this is the default element ID of the chat component */
  height: 1600px; /* adjust the height as needed */
  flex-grow: 4;
}
"""
# Assemble the page: one chat interface driven by generate(), with the
# sampling sliders attached as additional inputs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.ChatInterface(fn=generate, additional_inputs=additional_inputs)

# Enable request queuing, then start the server with debug output.
demo.queue().launch(debug=True)