import os

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

# Authenticate so gated models (e.g. meta-llama checkpoints) can be downloaded.
login(token=os.getenv("HUGGINGFACEHUB_API_KEY"))

def get_device_type() -> str:
    """Return the best available torch device: CUDA, Apple MPS, or CPU."""
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

# Other models tried:
# model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
# model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"  # ~15 GB
# Note: the base "meta-llama/Llama-3.2-3B" (non-Instruct) does not work here.
model_name = "meta-llama/Llama-3.2-3B-Instruct"  # ~6.5 GB

device = get_device_type()
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
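
# Optional (a sketch, not in the original script): on a CUDA device the model
# can be loaded in float16 to roughly halve memory use for inference, e.g.:
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, torch_dtype=torch.float16
#   ).to(device)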

system_prompt = """
You are a customer support officer. Extract the transaction ID from the USER INPUT and determine its type.

Extraction rules:
- Look for words starting with 'payout' or 'payment'.
- The next character must be a dash ('-').
- There must be at least one character (letter or digit) after the dash.
- The transaction ID must appear exactly in the USER INPUT.
- If found, set found to true; otherwise, set found to false.

Type rules:
- If the transaction ID starts with 'payout', type is payout.
- If it starts with 'payment', type is payment.

===>USER INPUT BEGINS
{input}
<===USER INPUT ENDS

Respond in valid JSON with these fields:
found: (boolean) Whether a valid transaction ID was found.
transaction_id: (string, if found) The extracted transaction ID.
transaction_type: (string, if found) The transaction type in lowercase.
justification: (string) Explain how you determined the transaction ID and type. If not found, do not fabricate an explanation.
Return only valid JSON and nothing else.
"""

examples = [
    "My transaction payment-a1c1 failed", 
    "Why is my withdrawal payout-b2c2 pending for 3 days", 
    "There is an issue with my transaction payout-87l2k3",
    "I am having trouble with my transaction",
]

def predict(message, history):
    # Always inject the current user message into the system prompt's {input} placeholder.
    sys_prompt = system_prompt.replace("{input}", message)
    if not history or history[0].get("role") != "system":
        history = [{"role": "system", "content": sys_prompt}] + history
    else:
        history[0]["content"] = sys_prompt

    history.append({"role": "user", "content": message})

    # 1. Build the prompt from history using the model's chat template.
    #    add_generation_prompt=True appends the assistant header so the model
    #    answers the user instead of continuing the user's turn.
    prompt = tokenizer.apply_chat_template(
        history, tokenize=False, add_generation_prompt=True
    )
    # 2. Tokenize the prompt for model input.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # 3. Generate the response (max_new_tokens caps the length of the JSON reply).
    outputs = model.generate(**inputs, max_new_tokens=100)

    # 4. Decode only the newly generated tokens. Slicing off the prompt length
    #    is template-agnostic, so it works for Llama 3.2's header tokens as well
    #    as ChatML-style markers like <|im_start|>assistant.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return response
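
# A minimal sketch (not part of the original app) for consuming the reply
# downstream: best-effort JSON parsing, since models occasionally wrap the
# JSON in extra prose despite the "JSON only" instruction.
import json

def parse_model_json(raw: str) -> dict:
    """Parse the model's JSON reply, salvaging the outermost {...} block if needed."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        start, end = raw.find("{"), raw.rfind("}")
        if start != -1 and end > start:
            try:
                return json.loads(raw[start:end + 1])
            except json.JSONDecodeError:
                pass
        return {"found": False, "justification": "unparseable model output"}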

demo = gr.ChatInterface(predict, type="messages", examples=examples)

# share=True exposes the app through a temporary public Gradio link.
demo.launch(share=True)