import os

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

# Authenticate with the Hub so gated models such as meta-llama/Llama-3.2-3B-Instruct can be downloaded
login(token=os.getenv("HUGGINGFACEHUB_API_KEY"))

def get_device_type() -> str:
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

# Alternative models:
#   HuggingFaceTB/SmolLM2-135M-Instruct
#   deepseek-ai/DeepSeek-R1-Distill-Qwen-7B  (~15 GB)
# Note: the base model meta-llama/Llama-3.2-3B (non-instruct) does not work here.
model_name = "meta-llama/Llama-3.2-3B-Instruct"  # ~6.5 GB
device = get_device_type()
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
system_prompt = """
You are a customer support officer. Extract the transaction ID from the USER INPUT and determine its type.
Extraction rules:
- Look for words starting with 'payout' or 'payment'.
- The next character must be a dash ('-').
- There must be at least one digit or character after the dash.
- The transaction ID must appear exactly in the USER INPUT.
- If found, set found to true; otherwise, set found to false.
Type rules:
- If the transaction ID starts with 'payout', type is payout.
- If it starts with 'payment', type is payment.
===>USER INPUT BEGINS
{input}
<===USER INPUT ENDS
Respond in valid JSON with these fields:
found: (boolean) Whether a valid transaction ID was found.
transaction_id: (string, if found) The extracted transaction ID.
transaction_type: (string, if found) The transaction type in lowercase.
justification: (string) Explain how you determined the transaction ID and type. If not found, do not fabricate an explanation.
Return only valid JSON and nothing else.
"""
examples = [
"My transaction payment-a1c1 failed",
"Why is my withdrawal payout-b2c2 pending for 3 days",
"There is an issue with my transaction payout-87l2k3",
"I am having trouble with my transaction",
]
def predict(message, history):
    # Always inject the current user message into the system prompt's {input} placeholder
    sys_prompt = system_prompt.replace("{input}", message)
    if not history or history[0].get("role") != "system":
        history = [{"role": "system", "content": sys_prompt}] + history
    else:
        history[0]["content"] = sys_prompt
    history.append({"role": "user", "content": message})
    # 1. Build the prompt from history using the model's chat template,
    #    ending with the assistant header so the model starts its reply
    prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    # 2. Tokenize the prompt for model input; the chat template already added special tokens
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    # 3. Generate the response
    outputs = model.generate(**inputs, max_new_tokens=100)
    # 4. Decode only the newly generated tokens so the prompt and history are not echoed back
    generated = outputs[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(generated, skip_special_tokens=True).strip()
    return response

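# With type="messages", gr.ChatInterface calls predict(message, history), where history is the
# prior conversation as a list of {"role": ..., "content": ...} dicts.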
demo = gr.ChatInterface(predict, type="messages", examples=examples)
demo.launch(share=True)